# Source: Kernels hub page (commit author: dongseokmotif).
# Commit e8e2c81 (unverified): "feat: extend QK-Clip to support MLA
# (MuonClip Algorithm 1) [skip-build] (#28)"
import logging
import types
from collections import defaultdict
from typing import Any
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.profiler import record_function
from .adamw import _placement_cache, _tensor_cache, step_adamw
from .async_utils import run_pipeline
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
get_default_muon_param_groups, is_expert_param, update_p)
from .cpu_offload import CPUOffloadPool
from .distributed.utils import (_is_shard, construct_shard_mesh,
get_slices_of_dtensor)
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
_zeropower_via_newtonschulz5,
zeropower_via_newtonschulz5,
zeropower_via_newtonschulz5_batched)
from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
from .qk_clip import compute_scales, get_qk_clip_info, qk_clip
logger = logging.getLogger(__name__)
def _expert_tp_layout(p):
    """Describe the non-expert (e.g. tensor-parallel) sharding of DTensor ``p``.

    Every ``Shard(d)`` placement with ``d != 0`` is kept and re-expressed as
    ``Shard(d - 1)``, because splitting off dim 0 shifts every later dim down
    by one.

    Returns:
        ``(tp_mesh, tp_placements)``. ``tp_mesh`` is ``None`` when ``p`` has no
        such placement. Otherwise it is the submesh of ``p.device_mesh`` made
        of the corresponding mesh dimensions, selected by name.
    """
    mesh_dims = []
    shifted = []
    for mesh_dim, placement in enumerate(p.placements):
        if not _is_shard(placement) or placement.dim == 0:
            continue
        mesh_dims.append(mesh_dim)
        shifted.append(Shard(placement.dim - 1))
    if not mesh_dims:
        return None, shifted
    dim_names = tuple(p.device_mesh.mesh_dim_names[d] for d in mesh_dims)
    # A single mesh dim is selected by its bare name, several by a name tuple.
    selector = dim_names[0] if len(dim_names) == 1 else dim_names
    return p.device_mesh[selector], shifted


def _expand_expert_params(names, params, expert_keys):
    """Split expert-parallel params on dim 0 into per-expert 2D params.

    A param counts as an expert param when ``is_expert_param(name,
    expert_keys)`` is true. Its dim 0 is treated as the expert dimension.
    Each local expert ``i`` becomes a new ``nn.Parameter`` (named
    ``f"{name}[{i}]"``) built on a view of the local data. Its ``.grad`` is
    the matching slice of the local grad, so in-place updates land in the
    original storage.

    Non-expert params are passed through unchanged. They must have
    ``ndim <= 2``; otherwise an ``AssertionError`` asks the caller to add the
    param's key to ``expert_keys``.

    For DTensor expert params, dim-0 sharding is consumed by the split. Any
    other ``Shard(d)`` placement (e.g. TP) is kept: each slice is re-wrapped
    as a DTensor with ``Shard(d - 1)`` on the matching submesh.

    Expert params must already carry a grad (e.g. after momentum). Their
    ``.grad`` is cleared after expansion so that storage can be freed once
    the per-expert grads are dropped.

    Returns:
        ``(expanded_names, expanded_params)``, in input order, with each
        expert param replaced by its slices.
    """
    out_names = []
    out_params = []
    for name, param in zip(names, params):
        expert = is_expert_param(name, expert_keys)
        sharded = isinstance(param.data, DTensor)

        if not expert:
            assert param.data.ndim <= 2, (
                f"Param {name} has ndim={param.data.ndim} but does not match "
                f"expert_keys={expert_keys}. If this is an expert param, "
                f"add its key to expert_keys.")
            out_names.append(name)
            out_params.append(param)
            continue

        if sharded:
            logger.debug(
                "[expand_expert] %s: expert DTensor, shape=%s, "
                "placements=%s, mesh=%s, local_shape=%s", name, param.shape,
                param.placements, param.device_mesh.mesh_dim_names,
                param.to_local().shape)
        else:
            logger.debug(
                "[expand_expert] %s: expert plain tensor, shape=%s", name,
                param.data.shape)

        grad = param.grad
        assert grad is not None, (
            f"Expert param {name} must have grad set before expansion")

        if sharded:
            data_local = param.to_local()
            grad_local = grad.to_local() if isinstance(grad, DTensor) else grad
            tp_mesh, tp_placements = _expert_tp_layout(param)
        else:
            data_local, grad_local = param.data, grad
            tp_mesh, tp_placements = None, None

        for expert_idx in range(data_local.shape[0]):
            data_slice = data_local[expert_idx]
            grad_slice = grad_local[expert_idx]
            if tp_mesh is None:
                expert_param = torch.nn.Parameter(data_slice,
                                                  requires_grad=False)
                expert_param.grad = grad_slice
            else:
                # Keep the TP sharding visible so the parallel pipeline
                # performs the cross-TP-rank gather/scatter.
                wrapped_data = DTensor.from_local(data_slice,
                                                  device_mesh=tp_mesh,
                                                  placements=tp_placements)
                wrapped_grad = DTensor.from_local(grad_slice,
                                                  device_mesh=tp_mesh,
                                                  placements=tp_placements)
                expert_param = torch.nn.Parameter(wrapped_data,
                                                  requires_grad=False)
                expert_param.grad = wrapped_grad
            out_names.append(f"{name}[{expert_idx}]")
            out_params.append(expert_param)

        param.grad = None  # allow expert grad storage to be freed after pipeline
    return out_names, out_params
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- We believe this optimizer is unlikely to work well for training with small batch size.
- We believe it may not work well for finetuning pretrained models, but we haven't tested this.
Arguments:
model: The model to be optimized by Muon.
is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
momentum: The momentum used by the internal SGD. (0.95 is a good default)
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
weight_decay: The weight decay for Muon and AdamW.
Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW instead.
adamw_lr: The learning rate for the internal AdamW.
adamw_betas: The betas for the internal AdamW.
adamw_eps: The epsilon for the internal AdamW.
none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
debug: Whether to print debug information.
clip_info : Configuration for QK clipping. Expected keys:
- "q_indices" (list[int]): Indices of query heads to consider.
- "k_indices" (list[int]): Indices of key heads to consider.
- "head_dim" (int): Dimensionality of each attention head.
- "threshold" (float): Threshold value; heads whose QK logits exceed
this value will be scaled down.
Default is:
{
"q_indices": [],
"k_indices": [],
"head_dim": 128,
"threshold": 100
}
warmup_step : How many all2all gather, compute operations are launched in advance
before the corresponding all2all scatter steps begin.
A higher warmup_step increases memory usage but can improve
performance by overlapping communication.
Parallel muon only.
chunk_size : Batch size of parameters to process in each
all2all gather/compute/scatter step.
Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
use_distributed_muon: Use distributed muon by Liu et al. (2024).
For testing purpose only.
expert_keys: List of strings to identify expert-parallel parameters.
If any key appears in a parameter's name, its outermost
dimension is treated as the expert dimension and expanded
into per-expert 2D params for Muon. For example,
``expert_keys=["experts"]`` matches any param whose name
contains "experts". 3D+ params not matched by any key
will raise an error.
"""
def __init__(self,
params,
lr=1e-3,
momentum=0.95,
nesterov=True,
ns_steps=5,
weight_decay=0.1,
adamw_betas=(0.9, 0.95),
adamw_eps=1e-8,
none_grad=True,
debug=False,
clip_config=None,
warmup_step=5,
chunk_size=-1,
use_distributed_muon=False,
expert_keys=None):
defaults = dict(
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
none_grad=none_grad,
use_muon=True,
)
error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
if isinstance(params, types.GeneratorType):
raise ValueError(error_message.format(idx=0) + instruction_code)
for _idx, param_group in enumerate(params):
if param_group.get("use_muon", None) is None:
raise ValueError(
error_message.format(idx=_idx) + instruction_code)
super().__init__(params, defaults)
self.debug = debug
self.clip_config = clip_config if clip_config is not None else {
"q_indices": [],
"k_indices": [],
"head_dim": 128,
"threshold": 100,
}
self.warmup_step = warmup_step
self.chunk_size = chunk_size
self.use_distributed_muon = use_distributed_muon
self.expert_keys = expert_keys
self.cpu_offload = False
self._cpu_offload_pool: CPUOffloadPool | None = None
self._offload_initialized = False
self._parallel_cache: dict[tuple[str, ...], dict] = {}
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
def _calc_flops(self, G, steps):
assert len(G.shape) == 2
M, N = G.shape
if M > N:
M, N = N, M
return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
def get_shard_mesh(self, p):
"""
Get the shard mesh for a parameter p on the given rank.
"""
assert isinstance(
p, DTensor), "Parallel Muon only supports DTensor parameters."
shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
p.placements, p.device_mesh)
return shard_mesh, shard_pg, shard_placements
def init_state_and_assign_params(self, names, params, group, qk_logits):
param_to_state = {}
param_to_flops = {}
total_flops = 0
for p in params:
g = p.grad
if g is None:
continue
assert g.ndim == 2, "Muon only supports 2D parameters."
flops = self._calc_flops(g, group["ns_steps"])
param_to_flops[id(p)] = flops
total_flops += flops
if self.debug:
logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
total_flops / 1e12)
paired = list(zip(names, params))
paired_sorted = sorted(paired,
key=lambda x: param_to_flops[id(x[1])],
reverse=True)
names_sorted, params_sorted = zip(*paired_sorted)
ordered_names = list(names_sorted)
ordered_params = list(params_sorted)
round_robin = 0
mesh = ordered_params[0].device_mesh
placements = ordered_params[0].placements
shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
ordered_params[0])
shard_mesh_flattened = shard_mesh.mesh.flatten()
num_ranks = dist.get_world_size(group=shard_pg)
for n, p in zip(ordered_names, ordered_params):
if mesh != p.device_mesh:
raise ValueError("All parameters must be on the same mesh.")
if placements != p.placements:
raise ValueError("All parameters must have same placements.")
worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
round_robin = (round_robin + 1) % len(shard_mesh_flattened)
qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
# Precompute per-rank indices and numels for all-to-all.
rank_indices: dict[int, tuple] = {}
rank_numels: dict[int, int] = {}
for r in range(num_ranks):
indices = get_slices_of_dtensor(p, r, shard_mesh,
shard_placements)
rank_indices[r] = indices
numel = 1
for idx, dim_size in zip(indices, p.shape):
if isinstance(idx, slice):
start, stop, step = idx.indices(dim_size)
numel *= max(0, (stop - start + (step - 1)) // step)
else:
numel *= len(idx)
rank_numels[r] = numel
param_to_state[id(p)] = _muon_state(
worker_rank=worker_rank,
process_group=shard_pg,
rank_indices=rank_indices,
rank_numels=rank_numels,
name=n,
qk_clip_state=qk_clip_state,
)
return param_to_state, ordered_params
def base(self, names, params, group, lr, weight_decay, qk_logits):
# Momentum is already applied by _step_muon before this method.
for n, p in zip(names, params):
g = p.grad
if g is None:
continue
u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
steps=group["ns_steps"])
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
update_p(p, u, lr, adjusted_lr, weight_decay)
qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
scales_full = compute_scales(
p, qk_clip_state) if qk_clip_state is not None else None
if scales_full is not None:
qk_clip(p, scales_full, qk_clip_state)
def distributed_muon(
self,
names: list[str],
params: list[torch.nn.Parameter],
group: dict[str, Any],
lr: float,
weight_decay: float,
qk_logits: list[torch.Tensor | DTensor] | None,
):
"""Batched Distributed Muon — for testing/correctness verification only.
Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
the full grad, then slices back to local shards. This is simpler but
slower than the parallel pipeline (all2all) path, so it serves as a
reference implementation for verifying correctness.
"""
with record_function("distributed_muon"):
# Momentum is already applied by _step_muon before this method.
ns_steps = group["ns_steps"]
# Separate plain tensors (no communication) from DTensors.
plain_names, plain_params = [], []
dtensor_names, dtensor_params = [], []
for n, p in zip(names, params):
if p.grad is None:
continue
if isinstance(p.data, DTensor):
dtensor_names.append(n)
dtensor_params.append(p)
else:
plain_names.append(n)
plain_params.append(p)
# Process plain tensors per-param (no communication).
for n, p in zip(plain_names, plain_params):
u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
steps=ns_steps)
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
update_p(p, u, lr, adjusted_lr, weight_decay)
qk_clip_state = get_qk_clip_info(self.clip_config, n,
qk_logits)
scales_full = compute_scales(
p, qk_clip_state) if qk_clip_state is not None else None
if scales_full is not None:
qk_clip(p, scales_full, qk_clip_state)
if not dtensor_params:
return
# Group DTensors by (placements, mesh) for batched all-gather.
placement_groups: dict[tuple,
tuple[list,
list]] = defaultdict(lambda: ([], []))
for n, p in zip(dtensor_names, dtensor_params):
key = (p.placements, p.device_mesh)
placement_groups[key][0].append(n)
placement_groups[key][1].append(p)
logger.info(
"distributed_muon: %d placement groups, %d total dtensors",
len(placement_groups), len(dtensor_params))
for (placements, mesh), (grp_names,
grp_params) in placement_groups.items():
shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
placements, mesh)
rank = dist.get_rank(shard_pg)
world_size = dist.get_world_size(shard_pg)
logger.info(" group: %d params, placements=%s, world_size=%d",
len(grp_params), placements, world_size)
# Separate params that can be batched (all shard dims evenly
# divisible) from those needing per-param full_tensor
# (e.g. MoE gate weights with fewer rows than shard ranks).
# all_gather_into_tensor requires equal buffer sizes across
# ranks, so uneven splits must use DTensor full_tensor().
batch_names, batch_params = [], []
single_names, single_params = [], []
for n, p in zip(grp_names, grp_params):
even = all(p.shape[pl.dim] %
shard_mesh.mesh.shape[dim_idx] == 0
for dim_idx, pl in enumerate(shard_placements))
if even:
batch_names.append(n)
batch_params.append(p)
else:
single_names.append(n)
single_params.append(p)
# Process uneven-split params per-param via full_tensor().
for n, p in zip(single_names, single_params):
with record_function("distributed_muon::newton_schulz"):
g_full = p.grad.full_tensor().to(COMM_DTYPE)
u_full = _zeropower_via_newtonschulz5(g_full,
steps=ns_steps)
del g_full
with record_function("distributed_muon::update"):
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
p._local_tensor.mul_(1 - lr * weight_decay)
local_indices = get_slices_of_dtensor(
p, rank, shard_mesh, shard_placements)
u_local = u_full[local_indices]
p._local_tensor.add_(u_local, alpha=-adjusted_lr)
del u_full
qk_clip_state = get_qk_clip_info(
self.clip_config, n, qk_logits)
scales_full = compute_scales(
p, qk_clip_state
) if qk_clip_state is not None else None
if scales_full is not None:
ratio = p.shape[0] // scales_full.shape[0]
idx0 = local_indices[0]
if isinstance(idx0, slice):
start = idx0.start or 0
idx0 = torch.arange(start,
idx0.stop,
device=scales_full.device)
row_scales = scales_full[idx0 // ratio]
p._local_tensor.mul_(row_scales.view(-1, 1))
if not batch_params:
continue
logger.info(" batched=%d, single=%d", len(batch_params),
len(single_params))
# Concat all local grad shards into a single flat buffer.
with record_function("distributed_muon::gather"):
grad_locals = [
p.grad.to_local().to(COMM_DTYPE).flatten()
for p in batch_params
]
numels = [g.numel() for g in grad_locals]
grad_concat = torch.cat(grad_locals)
del grad_locals
# Single all-gather (replaces N separate full_tensor).
grad_gathered = torch.empty(
grad_concat.numel() * world_size,
dtype=COMM_DTYPE,
device="cuda",
)
dist.all_gather_into_tensor(grad_gathered,
grad_concat,
group=shard_pg)
total_numel = grad_concat.numel()
del grad_concat
# Precompute per-param offsets within the concat buffer.
offsets = []
off = 0
for ne in numels:
offsets.append(off)
off += ne
# Per-param: reconstruct full grad → NS → local update.
for i, (n, p) in enumerate(zip(batch_names, batch_params)):
with record_function("distributed_muon::newton_schulz"):
g_full = torch.empty(p.shape,
dtype=COMM_DTYPE,
device="cuda")
for r in range(world_size):
r_start = r * total_numel + offsets[i]
shard = grad_gathered[r_start:r_start + numels[i]]
indices = get_slices_of_dtensor(
p, r, shard_mesh, shard_placements)
g_full[indices] = shard.reshape(
g_full[indices].shape)
u_full = _zeropower_via_newtonschulz5(g_full,
steps=ns_steps)
del g_full
with record_function("distributed_muon::update"):
adjusted_lr = adjust_lr_for_muon(lr, p.shape)
p._local_tensor.mul_(1 - lr * weight_decay)
local_indices = get_slices_of_dtensor(
p, rank, shard_mesh, shard_placements)
u_local = u_full[local_indices]
p._local_tensor.add_(u_local, alpha=-adjusted_lr)
del u_full
qk_clip_state = get_qk_clip_info(
self.clip_config, n, qk_logits)
scales_full = compute_scales(
p, qk_clip_state
) if qk_clip_state is not None else None
if scales_full is not None:
ratio = p.shape[0] // scales_full.shape[0]
idx0 = local_indices[0]
if isinstance(idx0, slice):
start = idx0.start or 0
idx0 = torch.arange(start,
idx0.stop,
device=scales_full.device)
row_scales = scales_full[idx0 // ratio]
p._local_tensor.mul_(row_scales.view(-1, 1))
def _setup_parallel(self, names, params, group, qk_logits):
"""Compute (or retrieve cached) parallel pipeline metadata.
Returns:
(ordered_params, param_to_state, rank, chunk_size)
"""
cache_key = tuple(names)
if cache_key not in self._parallel_cache:
# First call: compute metadata and populate cache.
param_to_state, ordered_params = self.init_state_and_assign_params(
names, params, group, qk_logits)
shard_pg = param_to_state[id(ordered_params[0])].process_group
rank = dist.get_rank(group=shard_pg)
if self.chunk_size == -1:
shard_ranks = dist.get_world_size(shard_pg)
chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
elif self.chunk_size > 0:
chunk_size = self.chunk_size
else:
raise ValueError(
"chunk_size must be -1 or a positive integer.")
ordered_names = [
param_to_state[id(p)].name for p in ordered_params
]
name_to_state = {
param_to_state[id(p)].name: param_to_state[id(p)]
for p in ordered_params
}
self._parallel_cache[cache_key] = {
'ordered_names': ordered_names,
'name_to_state': name_to_state,
'rank': rank,
'chunk_size': chunk_size,
}
else:
# Cached path: rebuild param_to_state with current id(p) keys.
cache = self._parallel_cache[cache_key]
rank = cache['rank']
chunk_size = cache['chunk_size']
name_to_param = dict(zip(names, params))
ordered_params = [name_to_param[n] for n in cache['ordered_names']]
param_to_state = {}
for p, n in zip(ordered_params, cache['ordered_names']):
cached_state = cache['name_to_state'][n]
param_to_state[id(p)] = _muon_state(
worker_rank=cached_state.worker_rank,
process_group=cached_state.process_group,
rank_indices=cached_state.rank_indices,
rank_numels=cached_state.rank_numels,
name=n,
qk_clip_state=get_qk_clip_info(self.clip_config, n,
qk_logits),
)
return ordered_params, param_to_state, rank, chunk_size
def parallel(self,
names,
params,
group,
lr,
weight_decay,
qk_logits,
prelaunch_gather=None):
"""
Perform a parallel optimization step using Muon.
Parameters are chunked and each chunk is processed by a
:func:`muon_chunk_pipeline` generator. :func:`run_pipeline`
interleaves multiple chunks so that communication and computation
overlap across chunks (the same overlap previously achieved by the
warmup + main-loop index scheduling).
If ``prelaunch_gather`` is provided, it is passed to the first
chunk's generator to skip re-launching the already in-flight
A2A gather.
"""
# Momentum is already applied by _step_muon before this method.
ordered_params, param_to_state, rank, chunk_size = (
self._setup_parallel(names, params, group, qk_logits))
def pipelines():
first = True
for start in range(0, len(ordered_params), chunk_size):
chunk = ordered_params[start:start + chunk_size]
if chunk:
kwargs = dict(
params=chunk,
param_to_state=param_to_state,
rank=rank,
ns_steps=group["ns_steps"],
lr=lr,
weight_decay=weight_decay,
none_grad=group["none_grad"],
)
if first and prelaunch_gather is not None:
kwargs['prelaunch_gather'] = prelaunch_gather
first = False
yield muon_chunk_pipeline(**kwargs)
with record_function("muon::pipeline"):
run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)
def _step_muon(self, group, qk_logits=None):
params = group["params"]
lr = group["lr"]
weight_decay = group["weight_decay"]
momentum = group["momentum"]
names = group["names"]
# Apply momentum to all params before routing/expansion.
# Batched using _foreach_* ops (compiled, fullgraph=True).
with record_function("muon::momentum"):
active_params = [p for p in params if p.grad is not None]
if active_params:
# Ensure momentum buffers exist (avoid zeros_like when already present).
for p in active_params:
if "momentum_buffer" not in self.state[p]:
self.state[p]["momentum_buffer"] = torch.zeros_like(
p.grad)
# Extract local tensors for compiled batch function.
local_grads = [
p.grad._local_tensor
if isinstance(p.grad, DTensor) else p.grad
for p in active_params
]
local_bufs = [
self.state[p]["momentum_buffer"]._local_tensor
if isinstance(self.state[p]["momentum_buffer"], DTensor)
else self.state[p]["momentum_buffer"]
for p in active_params
]
# Wrap momentum as tensor for torch.compile.
batch_pre_ortho(local_grads, local_bufs,
torch.tensor(momentum), group["nesterov"])
# For non-nesterov, the result is the momentum buffer.
if not group["nesterov"]:
for p in active_params:
p.grad = self.state[p]["momentum_buffer"]
# Identify batched experts for deferred NS.
# Detection is cheap (condition checks only); actual NS compute is
# deferred so it can overlap with the first chunk's A2A gather.
deferred_expert_work = []
if self.expert_keys:
batched_expert_indices = []
for i, (n, p) in enumerate(zip(names, params)):
if not (is_expert_param(n, self.expert_keys)
and p.grad is not None):
continue
# Eligible: plain tensor, or DTensor with no non-dim-0 shards.
if isinstance(p.data, DTensor):
has_tp = any(
_is_shard(pl) and pl.dim != 0 for pl in p.placements)
if has_tp:
continue
batched_expert_indices.append(i)
if batched_expert_indices:
# Save refs for deferred NS; free grads from param list.
for i in batched_expert_indices:
p = params[i]
g = p.grad
local_g = (g._local_tensor
if isinstance(g, DTensor) else g)
local_data = (p.data._local_tensor if isinstance(
p.data, DTensor) else p.data)
deferred_expert_work.append((local_data, local_g))
p.grad = None
# Remove batched experts from lists before expansion.
keep = sorted(
set(range(len(params))) - set(batched_expert_indices))
names = [names[i] for i in keep]
params = [params[i] for i in keep]
def _run_deferred_expert_ns():
"""Execute deferred batched expert NS."""
if not deferred_expert_work:
return
with record_function("muon::batched_expert_ns"):
ns_steps = group["ns_steps"]
for local_data, local_g in deferred_expert_work:
u = zeropower_via_newtonschulz5_batched(
local_g.to(COMM_DTYPE), steps=ns_steps)
adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
local_data.mul_(1 - lr * weight_decay)
local_data.add_(u, alpha=-adjusted_lr)
# Expand expert params by splitting on dim 0.
logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
len(params), self.expert_keys)
if self.expert_keys:
cache_key = tuple(id(p) for p in params)
cache = self._expert_expand_cache.get(cache_key)
if cache is None:
# Cold path: full expansion + build cache metadata.
exp_names, exp_params = _expand_expert_params(
names, params, self.expert_keys)
# Build per-expert-group info for hot-path grad updates.
grad_info = []
exp_idx = 0
for orig_idx, (n, p) in enumerate(zip(names, params)):
if not is_expert_param(n, self.expert_keys):
exp_idx += 1
continue
is_dt = isinstance(p.data, DTensor)
num_experts = (p.to_local() if is_dt else p.data).shape[0]
# Detect TP mesh from the first expanded expert param.
tp_mesh = None
tp_pls = None
sample = exp_params[exp_idx]
if isinstance(sample.data, DTensor):
tp_mesh = sample.data.device_mesh
tp_pls = list(sample.data.placements)
grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
tp_mesh, tp_pls))
exp_idx += num_experts
self._expert_expand_cache[cache_key] = {
'names': exp_names,
'params': exp_params,
'grad_info': grad_info,
}
names, params = exp_names, exp_params
else:
# Hot path: reuse cached params, only update expert grads.
for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
tp_pls) in cache['grad_info']:
p = params[orig_idx]
g = p.grad
local_grad = (g.to_local()
if is_dt and isinstance(g, DTensor) else g)
for i in range(num_experts):
expert_p = cache['params'][exp_start + i]
sg = local_grad[i]
if tp_mesh is not None:
expert_p.grad = DTensor.from_local(
sg, device_mesh=tp_mesh, placements=tp_pls)
else:
expert_p.grad = sg
p.grad = None
names = cache['names']
params = cache['params']
else:
names, params = _expand_expert_params(names, params,
self.expert_keys)
logger.debug("[_step_muon] after expand: %d params", len(params))
param_dtensors = []
name_dtensors = []
param_tensors = []
name_tensors = []
# distributed_muon is a reference implementation for testing only.
# The parallel pipeline (all2all) path below is the production path.
if self.use_distributed_muon:
_run_deferred_expert_ns()
self.distributed_muon(names=names,
params=params,
group=group,
lr=lr,
weight_decay=weight_decay,
qk_logits=qk_logits)
return
for n, p in zip(names, params):
if p is None or p.grad is None:
continue
if isinstance(p.data, DTensor):
if all(
isinstance(placement, Replicate)
for placement in p.placements):
logger.debug(
"[route] %s → base (DTensor all-Replicate), "
"shape=%s, placements=%s", n, p.shape, p.placements)
param_tensors.append(p)
name_tensors.append(n)
else:
logger.debug(
"[route] %s → parallel (DTensor), shape=%s, "
"placements=%s, mesh=%s", n, p.shape, p.placements,
p.device_mesh.mesh_dim_names)
param_dtensors.append(p)
name_dtensors.append(n)
elif isinstance(p.data, torch.Tensor):
logger.debug("[route] %s → base (plain tensor), shape=%s", n,
p.data.shape)
param_tensors.append(p)
name_tensors.append(n)
else:
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
f"{len(param_tensors)} Tensors → base")
def group_dtensors(dtensors, names):
# To support different placements, we group parameters by placements
# and run parallel Muon on each group.
placement_to_params = defaultdict(lambda: ([], []))
assert len(dtensors) == len(names)
for p, n in zip(dtensors, names):
placement_to_params[tuple([p.placements,
p.device_mesh])][0].append(n)
placement_to_params[tuple([p.placements,
p.device_mesh])][1].append(p)
return placement_to_params
if len(param_dtensors) > 0:
if not dist.is_initialized():
raise RuntimeError(
"Parallel Muon requires torch.distributed to be initialized."
)
dtensor_group = group_dtensors(param_dtensors, name_dtensors)
# Pre-launch the first chunk's A2A gather so that the NCCL
# communication overlaps with the (deferred) batched expert NS
# compute on the default CUDA stream.
prelaunch = None
if deferred_expert_work:
first_names, first_params = next(iter(dtensor_group.values()))
ordered, pts, rnk, csz = self._setup_parallel(
first_names, first_params, group, qk_logits)
first_chunk = ordered[:csz]
if first_chunk:
prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
group["none_grad"])
_run_deferred_expert_ns()
first_group = True
for _, (names, params) in dtensor_group.items():
pg = prelaunch if first_group else None
first_group = False
self.parallel(
names,
params,
group,
lr=lr,
weight_decay=weight_decay,
qk_logits=qk_logits,
prelaunch_gather=pg,
)
else:
_run_deferred_expert_ns()
if len(param_tensors) > 0:
self.base(
name_tensors,
param_tensors,
group,
lr=lr,
weight_decay=weight_decay,
qk_logits=qk_logits,
)
def _register_states_for_offload(self):
"""Register all optimizer state tensors with the CPU offload pool.
Called once after the first step when states have been lazily created.
Offloads all param states (momentum buffers for Muon, moment1/moment2
for AdamW) to free GPU memory between steps.
"""
pool = self._cpu_offload_pool
tracked = 0
for group in self.param_groups:
for p in group["params"]:
if p not in self.state:
continue
state = self.state[p]
if group.get("use_muon", False):
if "momentum_buffer" in state:
pool.track(state["momentum_buffer"])
tracked += 1
else:
if "moment1" in state:
pool.track(state["moment1"])
if "moment2" in state:
pool.track(state["moment2"])
tracked += 1
logger.info("[CPUOffload] Registered %d param states for offload",
tracked)
@torch.no_grad
def step(self, closure=None, qk_logits=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
to 1D tensors of shape (num_heads,), representing the maximum
QK logits across all tokens, computed as
(1 / sqrt(head_dim)) * (Q @ K^T).
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
# H2D: reload optimizer states from CPU before computation.
if self.cpu_offload and self._offload_initialized:
self._cpu_offload_pool.reload()
logger.debug("[Muon.step] expert_keys=%s, %d param groups",
self.expert_keys, len(self.param_groups))
for i, group in enumerate(self.param_groups):
if group["use_muon"]:
logger.debug("[Muon.step] group %d: use_muon=True, %d params",
i, len(group["params"]))
self._step_muon(group, qk_logits=qk_logits)
else:
logger.debug(
"[Muon.step] group %d: use_muon=False (AdamW), %d params",
i, len(group["params"]))
step_adamw(self.state, group)
# D2H: offload optimizer states to CPU after computation.
if self.cpu_offload:
if not self._offload_initialized:
if self._cpu_offload_pool is None:
self._cpu_offload_pool = CPUOffloadPool()
self._register_states_for_offload()
self._offload_initialized = True
self._cpu_offload_pool.offload()
return loss
# ------------------------------------------------------------------
# CPU offload public helpers
# ------------------------------------------------------------------
def turn_on_cpu_offload(self):
"""Enable CPU offload for optimizer states."""
if self.cpu_offload:
return
logger.info("[Muon] turn_on_cpu_offload")
self.cpu_offload = True
if not self.state:
return
self._cpu_offload_pool = CPUOffloadPool()
self._offload_initialized = False
self._register_states_for_offload()
self._offload_initialized = True
self._cpu_offload_pool.offload()
def turn_off_cpu_offload(self):
"""Disable CPU offload and keep optimizer states resident on GPU."""
if not self.cpu_offload:
return
logger.info("[Muon] turn_off_cpu_offload")
if self._offload_initialized:
self._cpu_offload_pool.reload()
torch.cuda.current_stream().synchronize()
self._cpu_offload_pool = None
self._offload_initialized = False
self.cpu_offload = False
# ------------------------------------------------------------------
# Checkpoint support for cpu_offload
# ------------------------------------------------------------------
def state_dict(self) -> dict:
if self.cpu_offload:
raise RuntimeError(
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
)
return super().state_dict()
def load_state_dict(self, state_dict: dict) -> None:
if self.cpu_offload:
raise RuntimeError(
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
)
super().load_state_dict(state_dict)
# Invalidate adamw.py's module-level tensor caches so that
# the next step rebuilds them with the newly loaded state tensors.
_placement_cache.clear()
_tensor_cache.clear()