Kernels
TaehyunKim
draft commit for cpu_offload (#23)
10848ab unverified
import logging
from collections import defaultdict
from typing import cast
import torch
from torch.distributed.tensor import DTensor
from torch.profiler import record_function
logger = logging.getLogger(__name__)
def fused_adamw(
params: list[torch.Tensor],
grads: list[torch.Tensor],
exp_avgs: list[torch.Tensor],
exp_avg_sqs: list[torch.Tensor],
max_exp_avg_sqs: list[torch.Tensor],
state_steps: list[torch.Tensor],
amsgrad: bool,
beta1: float,
beta2: float,
lr: float | torch.Tensor,
weight_decay: float,
eps: float,
maximize: bool,
) -> None:
if not params:
return
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict: dict | None = ({
lr.device: lr
} if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None)
grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
state_steps] # type: ignore[list-item]
)
for (device, _), (
(
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_max_exp_avg_sqs,
device_state_steps_,
),
_,
) in grouped_tensors.items():
device_params = cast(list[torch.Tensor], device_params_)
device_grads = cast(list[torch.Tensor], device_grads_)
device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
device_state_steps = cast(list[torch.Tensor], device_state_steps_)
if lr_dict is not None and device not in lr_dict:
lr_dict[device] = lr.to(
device=device, non_blocking=True) # type: ignore[union-attr]
lr = lr_dict[device]
torch._foreach_add_(device_state_steps, 1)
func = torch._fused_adamw_
func(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs, # type: ignore[arg-type]
device_state_steps,
amsgrad=amsgrad,
lr=lr, # type: ignore[arg-type]
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
)
def _to_local(t):
"""Unwrap DTensor to local tensor for fused ops."""
return t._local_tensor if isinstance(t, DTensor) else t
# ---------------------------------------------------------------------------
# Caches for eliminating per-step Python overhead.
#
# Placement grouping and tensor-list assembly are expected to be identical on
# every step (params don't change placement, and the moment/step tensors are
# the same objects after initialisation). Both results are therefore cached,
# keyed by the id() of the list object they were built from.
#
# Only gradients change each step and must be collected fresh.
#
# NOTE(review): the keys are id() values, so an entry is only meaningful while
# the list it was built from is alive and unchanged. Nothing in this section
# of the file evicts entries — confirm whether the caches are cleared
# elsewhere (e.g. when optimizer state is reloaded or an optimizer is
# destroyed), otherwise entries can go stale and keep tensors alive.
# ---------------------------------------------------------------------------
# id(group["params"]) → dict[placement_key, list[param]]
# The per-placement lists stored here are the lists passed on to
# step_adamw_params, so they stay alive for as long as the entry exists.
_placement_cache: dict[int, dict[tuple, list]] = {}
# id(placement_group_list) → (params_local, moment1, moment2, state_steps)
# All four lists are parallel: index i refers to the same parameter.
_tensor_cache: dict[int, tuple[list, list, list, list]] = {}
def _step_adamw_params_slow(optimizer_state, params, group):
    """Uncached AdamW step that skips every parameter without a gradient.

    Used as the fallback when at least one parameter in ``params`` has
    ``grad is None``. Optimizer state ("step", "moment1", "moment2") is
    created on first use, and a non-tensor "step" is converted to a float32
    scalar tensor on the parameter's device before the fused update runs.

    Args:
        optimizer_state: Mapping from parameter to its per-parameter state dict.
        params: Parameters to consider; those lacking a gradient are ignored.
        group: Parameter group dict providing "lr", "adamw_betas",
            "adamw_eps" and "weight_decay".
    """
    active_params = []
    active_grads = []
    first_moments = []
    second_moments = []
    step_counters = []

    for param in params:
        if (grad := param.grad) is None:
            continue

        param_state = optimizer_state[param]
        active_params.append(_to_local(param))
        active_grads.append(_to_local(grad))

        # Lazily create the state the first time this parameter is stepped.
        if "step" not in param_state:
            param_state["step"] = torch.zeros(
                (), dtype=torch.float32, device=param.device)
            param_state["moment1"] = torch.zeros_like(grad)
            param_state["moment2"] = torch.zeros_like(grad)

        first_moments.append(_to_local(param_state["moment1"]))
        second_moments.append(_to_local(param_state["moment2"]))

        # The fused kernel needs tensor step counters; coerce plain numbers.
        step = param_state["step"]
        if not isinstance(step, torch.Tensor):
            step = torch.tensor(
                step, dtype=torch.float32, device=param.device)
            param_state["step"] = step
        step_counters.append(step)

    # Nothing had a gradient: there is nothing to update.
    if not active_params:
        return

    learning_rate = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    epsilon = group["adamw_eps"]
    decay = group["weight_decay"]

    fused_adamw(
        active_params,
        active_grads,
        first_moments,
        second_moments,
        [],  # no max_exp_avg_sqs: amsgrad is disabled
        step_counters,
        amsgrad=False,
        beta1=beta1,
        beta2=beta2,
        lr=learning_rate,
        weight_decay=decay,
        eps=epsilon,
        maximize=False,
    )
def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    After the first call, cached tensor lists (params_local, moment1,
    moment2, state_steps) are reused — only gradients are collected fresh.
    If any parameter has no gradient, the whole call is delegated to the
    uncached slow path, which filters per parameter.

    The cache is keyed by ``id(params)``, so a cached entry is cheaply
    validated before reuse: its length must match ``params`` and its first
    entry must still be the very tensors referenced by the first parameter
    and that parameter's optimizer state. On mismatch (for example after the
    optimizer state was replaced by loading a state dict, or after the id was
    reused by a different list) the entry is rebuilt from the live state.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay.
    """
    # Collect grads — the only thing that changes each step.
    with record_function("adamw::collect_grads"):
        grads = []
        for p in params:
            g = p.grad
            if g is None:
                # Rare: fall back to slow path that filters per-param.
                _step_adamw_params_slow(optimizer_state, params, group)
                return
            grads.append(_to_local(g))

    tensor_key = id(params)
    cached = _tensor_cache.get(tensor_key)

    if cached is not None and params:
        # O(1) staleness probe: compare object identity of the first entry
        # against the live param/state. This does not re-validate every
        # parameter (that would reintroduce the per-step overhead the cache
        # removes), but it catches wholesale replacement of the state.
        cached_params, cached_m1, cached_m2, cached_steps = cached
        head_state = optimizer_state[params[0]]
        if (len(cached_params) != len(params)
                or cached_params[0] is not _to_local(params[0])
                or cached_steps[0] is not head_state.get("step")
                or cached_m1[0] is not _to_local(head_state.get("moment1"))
                or cached_m2[0] is not _to_local(head_state.get("moment2"))):
            cached = None

    if cached is None:
        with record_function("adamw::init_tensor_cache"):
            params_local = []
            moment1 = []
            moment2 = []
            state_steps = []
            for p in params:
                state = optimizer_state[p]
                params_local.append(_to_local(p))
                # Lazily create the state the first time a param is stepped.
                if "step" not in state:
                    state["step"] = torch.zeros((),
                                                dtype=torch.float32,
                                                device=p.device)
                    state["moment1"] = torch.zeros_like(p.grad)
                    state["moment2"] = torch.zeros_like(p.grad)
                moment1.append(_to_local(state["moment1"]))
                moment2.append(_to_local(state["moment2"]))
                # The fused kernel needs tensor step counters.
                if not isinstance(state["step"], torch.Tensor):
                    state["step"] = torch.tensor(state["step"],
                                                 dtype=torch.float32,
                                                 device=p.device)
                state_steps.append(state["step"])
            cached = (params_local, moment1, moment2, state_steps)
            _tensor_cache[tensor_key] = cached

    params_local, moment1, moment2, state_steps = cached

    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    with record_function("adamw::fused_adamw"):
        fused_adamw(
            params_local,
            grads,
            moment1,
            moment2,
            [],  # no max_exp_avg_sqs: amsgrad is disabled
            state_steps,
            amsgrad=False,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
def step_adamw(optimizer_state, group):
    """Dispatch the AdamW step for one parameter group.

    Parameters are bucketed by kind and placement: DTensors by their
    (placements, device_mesh) pair, plain tensors into a single bucket.
    The bucketing is computed on the first call for a given
    ``group["params"]`` list and cached under that list's id(); later calls
    reuse the cached buckets. Each bucket is then stepped separately.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        group: Parameter group dict.
    """
    params = group["params"]
    cache_key = id(params)

    buckets = _placement_cache.get(cache_key)
    if buckets is None:
        with record_function("adamw::group_by_placement"):
            grouped: dict[tuple, list[torch.Tensor]] = defaultdict(list)
            for param in params:
                # DTensor must be tested first: the plain-tensor branch is
                # only meant for params that are not DTensors.
                if isinstance(param, DTensor):
                    logger.debug(
                        "[AdamW] DTensor param: shape=%s, placements=%s, "
                        "mesh=%s, grad=%s", param.shape, param.placements,
                        param.device_mesh.mesh_dim_names,
                        param.grad.shape if param.grad is not None else None)
                    grouped[(param.placements,
                             param.device_mesh)].append(param)
                elif isinstance(param, torch.Tensor):
                    logger.debug(
                        "[AdamW] plain param: shape=%s, grad=%s", param.shape,
                        param.grad.shape if param.grad is not None else None)
                    grouped[(torch.Tensor, None)].append(param)
            logger.debug("[AdamW] %d placement groups, %d total params",
                         len(grouped), len(params))
            # Store a plain dict so later lookups cannot create new buckets.
            buckets = dict(grouped)
            _placement_cache[cache_key] = buckets

    for bucket in buckets.values():
        step_adamw_params(optimizer_state, bucket, group)