Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("Motif-Technologies/optimizer")
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| from collections import defaultdict | |
| from typing import cast | |
| import torch | |
| from torch.distributed.tensor import DTensor | |
| from torch.profiler import record_function | |
# Module-level logger; step_adamw emits its per-parameter debug records here.
logger = logging.getLogger(__name__)
| def fused_adamw( | |
| params: list[torch.Tensor], | |
| grads: list[torch.Tensor], | |
| exp_avgs: list[torch.Tensor], | |
| exp_avg_sqs: list[torch.Tensor], | |
| max_exp_avg_sqs: list[torch.Tensor], | |
| state_steps: list[torch.Tensor], | |
| amsgrad: bool, | |
| beta1: float, | |
| beta2: float, | |
| lr: float | torch.Tensor, | |
| weight_decay: float, | |
| eps: float, | |
| maximize: bool, | |
| ) -> None: | |
| if not params: | |
| return | |
| # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer | |
| # treating it as a scalar. | |
| lr_dict: dict | None = ({ | |
| lr.device: lr | |
| } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else None) | |
| grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( | |
| [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, | |
| state_steps] # type: ignore[list-item] | |
| ) | |
| for (device, _), ( | |
| ( | |
| device_params_, | |
| device_grads_, | |
| device_exp_avgs_, | |
| device_exp_avg_sqs_, | |
| device_max_exp_avg_sqs, | |
| device_state_steps_, | |
| ), | |
| _, | |
| ) in grouped_tensors.items(): | |
| device_params = cast(list[torch.Tensor], device_params_) | |
| device_grads = cast(list[torch.Tensor], device_grads_) | |
| device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) | |
| device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) | |
| device_state_steps = cast(list[torch.Tensor], device_state_steps_) | |
| if lr_dict is not None and device not in lr_dict: | |
| lr_dict[device] = lr.to( | |
| device=device, non_blocking=True) # type: ignore[union-attr] | |
| lr = lr_dict[device] | |
| torch._foreach_add_(device_state_steps, 1) | |
| func = torch._fused_adamw_ | |
| func( | |
| device_params, | |
| device_grads, | |
| device_exp_avgs, | |
| device_exp_avg_sqs, | |
| device_max_exp_avg_sqs, # type: ignore[arg-type] | |
| device_state_steps, | |
| amsgrad=amsgrad, | |
| lr=lr, # type: ignore[arg-type] | |
| beta1=beta1, | |
| beta2=beta2, | |
| weight_decay=weight_decay, | |
| eps=eps, | |
| maximize=maximize, | |
| ) | |
| def _to_local(t): | |
| """Unwrap DTensor to local tensor for fused ops.""" | |
| return t._local_tensor if isinstance(t, DTensor) else t | |
# ---------------------------------------------------------------------------
# Caches for eliminating per-step Python overhead.
#
# Placement grouping and tensor list assembly are identical every step
# (params don't change placement, moment/step tensors are the same objects
# after initialisation).  We cache them keyed by id() of the param list
# stored in param_groups (stable across steps).
#
# Only gradients change each step and must be collected fresh.
#
# NOTE(review): id() is only unique among objects that are alive at the same
# time, so an entry is only trustworthy while the list it was keyed on stays
# alive and the optimizer state it was built from is not replaced.
# ---------------------------------------------------------------------------
# id(group["params"]) → dict[placement_key, list[param]]
# Filled by step_adamw; each value list is what step_adamw_params receives.
_placement_cache: dict[int, dict[tuple, list]] = {}
# id(placement_group_list) → (params_local, moment1, moment2, state_steps)
# Filled by step_adamw_params; all four lists are index-aligned.
_tensor_cache: dict[int, tuple[list, list, list, list]] = {}
| def _step_adamw_params_slow(optimizer_state, params, group): | |
| """Uncached fallback for the rare case where some params lack grads.""" | |
| params_with_grads = [] | |
| grads = [] | |
| moment1 = [] | |
| moment2 = [] | |
| state_steps = [] | |
| for p in params: | |
| g = p.grad | |
| if g is None: | |
| continue | |
| state = optimizer_state[p] | |
| params_with_grads.append(_to_local(p)) | |
| grads.append(_to_local(g)) | |
| if "step" not in state: | |
| state["step"] = torch.zeros((), | |
| dtype=torch.float32, | |
| device=p.device) | |
| state["moment1"] = torch.zeros_like(g) | |
| state["moment2"] = torch.zeros_like(g) | |
| moment1.append(_to_local(state["moment1"])) | |
| moment2.append(_to_local(state["moment2"])) | |
| if not isinstance(state["step"], torch.Tensor): | |
| state["step"] = torch.tensor(state["step"], | |
| dtype=torch.float32, | |
| device=p.device) | |
| state_steps.append(state["step"]) | |
| if not params_with_grads: | |
| return | |
| lr = group["lr"] | |
| beta1, beta2 = group["adamw_betas"] | |
| eps = group["adamw_eps"] | |
| weight_decay = group["weight_decay"] | |
| fused_adamw( | |
| params_with_grads, | |
| grads, | |
| moment1, | |
| moment2, | |
| [], | |
| state_steps, | |
| amsgrad=False, | |
| beta1=beta1, | |
| beta2=beta2, | |
| lr=lr, | |
| weight_decay=weight_decay, | |
| eps=eps, | |
| maximize=False, | |
| ) | |
def _collect_adamw_tensors(optimizer_state, params):
    """Build (params_local, moment1, moment2, state_steps) for ``params``.

    Missing per-parameter state is created, and a non-tensor "step" is
    converted to a float32 tensor. Every parameter must already have a
    gradient, since new moment buffers are shaped like ``p.grad``.
    """
    params_local = []
    moment1 = []
    moment2 = []
    state_steps = []
    for p in params:
        state = optimizer_state[p]
        params_local.append(_to_local(p))
        if "step" not in state:
            state["step"] = torch.zeros((),
                                        dtype=torch.float32,
                                        device=p.device)
            state["moment1"] = torch.zeros_like(p.grad)
            state["moment2"] = torch.zeros_like(p.grad)
        moment1.append(_to_local(state["moment1"]))
        moment2.append(_to_local(state["moment2"]))
        if not isinstance(state["step"], torch.Tensor):
            state["step"] = torch.tensor(state["step"],
                                         dtype=torch.float32,
                                         device=p.device)
        state_steps.append(state["step"])
    return params_local, moment1, moment2, state_steps


def _cached_tensors_are_current(entry, optimizer_state, params):
    """Cheap O(1) spot check that a cache entry still matches the live state.

    Compares the list length and the first parameter's local tensor, first
    moment and step tensor by identity. This is a heuristic, not a full
    validation: it is meant to catch wholesale replacement of the state or
    reuse of an ``id()`` key, not edits to individual later entries.
    """
    params_local, moment1, _, state_steps = entry
    if len(params_local) != len(params):
        return False
    if not params:
        return True
    head = params[0]
    head_state = optimizer_state.get(head)
    if not head_state:
        return False
    return (params_local[0] is _to_local(head)
            and head_state.get("step") is state_steps[0]
            and _to_local(head_state.get("moment1")) is moment1[0])


def step_adamw_params(optimizer_state, params, group):
    """Run fused AdamW on a list of parameters sharing the same placement.

    After the first call, cached tensor lists (params_local, moment1,
    moment2, state_steps) are reused — only gradients are collected fresh.
    If any parameter has no gradient the whole call is delegated to the
    uncached slow path.

    The cache is keyed by ``id(params)``. Because an id can be reused and the
    optimizer state can be swapped out underneath the cache (presumably e.g.
    when a checkpoint is restored after a step has already run), each hit is
    spot-checked against the live state and rebuilt on mismatch instead of
    silently updating stale tensors.

    Args:
        optimizer_state: The optimizer's state dict (self.state in Muon).
        params: List of parameters to update.
        group: Parameter group dict with lr, adamw_betas, adamw_eps,
            weight_decay.
    """
    # Collect grads — the only thing that changes each step.
    with record_function("adamw::collect_grads"):
        grads = []
        for p in params:
            g = p.grad
            if g is None:
                # Rare: fall back to slow path that filters per-param.
                _step_adamw_params_slow(optimizer_state, params, group)
                return
            grads.append(_to_local(g))

    tensor_key = id(params)
    cached = _tensor_cache.get(tensor_key)
    if cached is not None and not _cached_tensors_are_current(
            cached, optimizer_state, params):
        cached = None  # stale entry: rebuild from the live optimizer state
    if cached is None:
        with record_function("adamw::init_tensor_cache"):
            cached = _collect_adamw_tensors(optimizer_state, params)
        _tensor_cache[tensor_key] = cached
    params_local, moment1, moment2, state_steps = cached

    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["weight_decay"]

    with record_function("adamw::fused_adamw"):
        fused_adamw(
            params_local,
            grads,
            moment1,
            moment2,
            [],
            state_steps,
            amsgrad=False,
            beta1=beta1,
            beta2=beta2,
            lr=lr,
            weight_decay=weight_decay,
            eps=eps,
            maximize=False,
        )
| def step_adamw(optimizer_state, group): | |
| """Dispatch AdamW step, grouping parameters by type and placement. | |
| Placement grouping is cached after the first call since params never | |
| change their placement between steps. | |
| Args: | |
| optimizer_state: The optimizer's state dict (self.state in Muon). | |
| group: Parameter group dict. | |
| """ | |
| params = group["params"] | |
| placement_key = id(params) | |
| if placement_key not in _placement_cache: | |
| with record_function("adamw::group_by_placement"): | |
| placement_to_params: dict[tuple, | |
| list[torch.Tensor]] = defaultdict(list) | |
| for p in params: | |
| match p: | |
| case DTensor(): | |
| logger.debug( | |
| "[AdamW] DTensor param: shape=%s, placements=%s, " | |
| "mesh=%s, grad=%s", p.shape, p.placements, | |
| p.device_mesh.mesh_dim_names, | |
| p.grad.shape if p.grad is not None else None) | |
| placement_to_params[tuple( | |
| [p.placements, p.device_mesh])].append(p) | |
| case torch.Tensor(): | |
| logger.debug( | |
| "[AdamW] plain param: shape=%s, grad=%s", p.shape, | |
| p.grad.shape if p.grad is not None else None) | |
| placement_to_params[tuple([torch.Tensor, | |
| None])].append(p) | |
| logger.debug("[AdamW] %d placement groups, %d total params", | |
| len(placement_to_params), len(params)) | |
| _placement_cache[placement_key] = dict(placement_to_params) | |
| for group_params in _placement_cache[placement_key].values(): | |
| step_adamw_params(optimizer_state, group_params, group) | |