import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor from torch.profiler import record_function logger = logging.getLogger(__name__) def fused_adamw( params: list[torch.Tensor], grads: list[torch.Tensor], exp_avgs: list[torch.Tensor], exp_avg_sqs: list[torch.Tensor], max_exp_avg_sqs: list[torch.Tensor], state_steps: list[torch.Tensor], amsgrad: bool, beta1: float, beta2: float, lr: float | torch.Tensor, weight_decay: float, eps: float, maximize: bool, ) -> None: if not params: return # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer # treating it as a scalar. lr_dict: dict | None = ({ lr.device: lr } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] ) for (device, _), ( ( device_params_, device_grads_, device_exp_avgs_, device_exp_avg_sqs_, device_max_exp_avg_sqs, device_state_steps_, ), _, ) in grouped_tensors.items(): device_params = cast(list[torch.Tensor], device_params_) device_grads = cast(list[torch.Tensor], device_grads_) device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) device_state_steps = cast(list[torch.Tensor], device_state_steps_) if lr_dict is not None and device not in lr_dict: lr_dict[device] = lr.to( device=device, non_blocking=True) # type: ignore[union-attr] lr = lr_dict[device] torch._foreach_add_(device_state_steps, 1) func = torch._fused_adamw_ func( device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs, # type: ignore[arg-type] device_state_steps, amsgrad=amsgrad, lr=lr, # type: ignore[arg-type] beta1=beta1, beta2=beta2, weight_decay=weight_decay, eps=eps, maximize=maximize, ) def _to_local(t): """Unwrap DTensor to local tensor for fused ops.""" return t._local_tensor if isinstance(t, DTensor) else t # --------------------------------------------------------------------------- # Caches for eliminating per-step Python overhead. # # Placement grouping and tensor list assembly are identical every step # (params don't change placement, moment/step tensors are the same objects # after initialisation). We cache them keyed by id() of the param list # stored in param_groups (stable across steps). # # Only gradients change each step and must be collected fresh. # --------------------------------------------------------------------------- # id(group["params"]) → dict[placement_key, list[param]] _placement_cache: dict[int, dict[tuple, list]] = {} # id(placement_group_list) → (params_local, moment1, moment2, state_steps) _tensor_cache: dict[int, tuple[list, list, list, list]] = {} def _step_adamw_params_slow(optimizer_state, params, group): """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] state_steps = [] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] params_with_grads.append(_to_local(p)) grads.append(_to_local(g)) if "step" not in state: state["step"] = torch.zeros((), dtype=torch.float32, device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) moment1.append(_to_local(state["moment1"])) moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): state["step"] = torch.tensor(state["step"], dtype=torch.float32, device=p.device) state_steps.append(state["step"]) if not params_with_grads: return lr = group["lr"] beta1, beta2 = group["adamw_betas"] eps = group["adamw_eps"] weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, [], state_steps, amsgrad=False, beta1=beta1, beta2=beta2, lr=lr, weight_decay=weight_decay, eps=eps, maximize=False, ) def step_adamw_params(optimizer_state, params, group): """Run fused AdamW on a list of parameters sharing the same placement. After the first call, cached tensor lists (params_local, moment1, moment2, state_steps) are reused — only gradients are collected fresh. Args: optimizer_state: The optimizer's state dict (self.state in Muon). params: List of parameters to update. group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. """ # Collect grads — the only thing that changes each step. with record_function("adamw::collect_grads"): grads = [] for p in params: g = p.grad if g is None: # Rare: fall back to slow path that filters per-param. _step_adamw_params_slow(optimizer_state, params, group) return grads.append(_to_local(g)) tensor_key = id(params) if tensor_key not in _tensor_cache: with record_function("adamw::init_tensor_cache"): params_local = [] moment1 = [] moment2 = [] state_steps = [] for p in params: state = optimizer_state[p] params_local.append(_to_local(p)) if "step" not in state: state["step"] = torch.zeros((), dtype=torch.float32, device=p.device) state["moment1"] = torch.zeros_like(p.grad) state["moment2"] = torch.zeros_like(p.grad) moment1.append(_to_local(state["moment1"])) moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): state["step"] = torch.tensor(state["step"], dtype=torch.float32, device=p.device) state_steps.append(state["step"]) _tensor_cache[tensor_key] = (params_local, moment1, moment2, state_steps) params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] lr = group["lr"] beta1, beta2 = group["adamw_betas"] eps = group["adamw_eps"] weight_decay = group["weight_decay"] with record_function("adamw::fused_adamw"): fused_adamw( params_local, grads, moment1, moment2, [], state_steps, amsgrad=False, beta1=beta1, beta2=beta2, lr=lr, weight_decay=weight_decay, eps=eps, maximize=False, ) def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. Placement grouping is cached after the first call since params never change their placement between steps. Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. """ params = group["params"] placement_key = id(params) if placement_key not in _placement_cache: with record_function("adamw::group_by_placement"): placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) for p in params: match p: case DTensor(): logger.debug( "[AdamW] DTensor param: shape=%s, placements=%s, " "mesh=%s, grad=%s", p.shape, p.placements, p.device_mesh.mesh_dim_names, p.grad.shape if p.grad is not None else None) placement_to_params[tuple( [p.placements, p.device_mesh])].append(p) case torch.Tensor(): logger.debug( "[AdamW] plain param: shape=%s, grad=%s", p.shape, p.grad.shape if p.grad is not None else None) placement_to_params[tuple([torch.Tensor, None])].append(p) logger.debug("[AdamW] %d placement groups, %d total params", len(placement_to_params), len(params)) _placement_cache[placement_key] = dict(placement_to_params) for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group)