Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("Motif-Technologies/optimizer")
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| import math | |
| from dataclasses import dataclass | |
| import torch | |
| from torch.distributed.tensor import DTensor | |
| from .core import normalize_fqn | |
| logger = logging.getLogger(__name__) | |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """
    Classify a parameter name as a query/key projection and find its layer.

    The name is first passed through ``normalize_fqn`` and split on ``'.'``.
    The second-to-last component is taken as the projection kind, and the
    right-most all-digit component is taken as the layer index.

    Recognised kinds:
        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
        MLA:     'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

    Returns:
        (kind, layer_idx) for a recognised kind, where layer_idx is -1 if
        the name contains no numeric component; (None, -1) when the name has
        fewer than three components or the kind is not recognised.

    Example:
        'model.3.attn.wq.weight'     -> ('wq', 3)
        'model.5.attn.wk.weight'     -> ('wk', 5)
        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
        'model.1.attn.wq_b.weight'   -> ('wq_b', 1)
        'model.0.attn.wkv_b.weight'  -> ('wkv_b', 0)
        'model.4.attn.v_proj.weight' -> (None, -1)
    """
    segments = normalize_fqn(name).split('.')
    if len(segments) < 3:
        return None, -1

    # Scan from the right so the innermost numeric component wins.
    layer_idx = next(
        (int(seg) for seg in reversed(segments) if seg.isdigit()),
        -1,
    )

    candidate = segments[-2]
    if candidate not in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return None, -1
    return candidate, layer_idx
| class QKClipInfo: | |
| """Per-parameter dynamic info computed from config + runtime logits.""" | |
| kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None | |
| indices: list[int] # which heads to consider for clipping | |
| head_dim: int # from config (qk_head_dim for MLA wq_b) | |
| threshold: float # from config | |
| logit: torch.Tensor | None | |
| # MLA-specific fields | |
| is_mla: bool = False | |
| qk_nope_head_dim: int = 0 | |
| qk_rope_head_dim: int = 0 | |
| v_head_dim: int = 0 | |
| def get_qk_clip_info(clip_config, n, qk_logits): | |
| """Extract QK clipping info for a named parameter. | |
| Args: | |
| clip_config: QK clipping configuration dict (or None). | |
| MHA/GQA keys: head_dim, threshold, q_indices, k_indices | |
| MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim | |
| n: Parameter name string. | |
| qk_logits: Dict mapping layer indices to logit tensors (or None). | |
| Returns: | |
| QKClipInfo instance with clipping configuration for this parameter. | |
| """ | |
| if clip_config is None: | |
| return None | |
| head_dim = clip_config.get('head_dim') | |
| threshold = clip_config.get('threshold') | |
| kind, layer_idx = parse_qk_layer(n) | |
| is_mla = clip_config.get('is_mla', False) | |
| logit, indices = None, [] | |
| if qk_logits is not None and kind is not None: | |
| logit = qk_logits[layer_idx] | |
| if isinstance(logit, DTensor): | |
| # In TP settings, qk_logits may be DTensor | |
| # We convert it to full tensor here for simplicity | |
| logit = logit.full_tensor() | |
| if kind in ('wq_b', 'wq', 'q_proj'): | |
| indices = clip_config.get('q_indices', []) or [] | |
| elif kind in ('wkv_b', 'wk', 'k_proj'): | |
| indices = clip_config.get('k_indices', []) or [] | |
| if is_mla: | |
| return QKClipInfo( | |
| kind=kind, | |
| indices=indices, | |
| head_dim=head_dim, | |
| threshold=threshold, | |
| logit=logit, | |
| is_mla=True, | |
| qk_nope_head_dim=clip_config['qk_nope_head_dim'], | |
| qk_rope_head_dim=clip_config['qk_rope_head_dim'], | |
| v_head_dim=clip_config['v_head_dim'], | |
| ) | |
| else: | |
| return QKClipInfo( | |
| kind=kind, | |
| indices=indices, | |
| head_dim=head_dim, | |
| threshold=threshold, | |
| logit=logit, | |
| ) | |
def compute_scales(p, qk_clip_state):
    """Derive per-head scaling factors for QK clipping.

    For every (position, head) pair in ``qk_clip_state.indices`` the logit at
    that position is compared with the threshold. A head whose logit exceeds
    it gets the factor sqrt(threshold / logit) (i.e. √γ); if a head appears
    more than once, the smallest factor is kept.

    Returns:
        None if no head exceeds the threshold. Otherwise a 1-D tensor of ones
        on ``p``'s device, one entry per head, with the clipped heads set to
        their factor. The head count is ``p.shape[0]`` divided by the rows
        per head; for MLA wkv_b that stride is qk_nope_head_dim + v_head_dim,
        otherwise head_dim.
    """
    kind = qk_clip_state.kind
    heads = qk_clip_state.indices
    head_dim = qk_clip_state.head_dim
    limit = qk_clip_state.threshold
    logits = qk_clip_state.logit

    # Collect offending heads first so nothing is allocated in the common
    # "no clipping needed" case.
    clipped = {}
    for pos, head in enumerate(heads):
        value = float(logits[pos])
        if not (value > limit):
            continue
        candidate = math.sqrt(limit / value)
        previous = clipped.get(head)
        if previous is None or candidate < previous:
            clipped[head] = candidate
            logger.info(
                f"[{kind}] Head {head} exceeded threshold "
                f"(value={value:.4f}, threshold={limit:.4f}) -> applying scale={candidate:.4f}"
            )

    if len(clipped) == 0:
        return None

    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows.
    rows_per_head = (
        qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
        if qk_clip_state.is_mla and kind == 'wkv_b'
        else head_dim
    )

    n_heads = p.shape[0] // rows_per_head
    result = torch.ones(n_heads, device=p.data.device)
    for head, factor in clipped.items():
        result[head] = factor
    return result
def qk_clip(p, scales, info):
    """Scale a Q/K projection weight in place, one factor per head.

    Args:
        p: Parameter (nn.Parameter or raw tensor). For an nn.Parameter the
            underlying ``.data`` tensor is modified.
        scales: [n_heads] tensor, each element = √γ_h.
        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.

    The weight is viewed as [heads, rows_per_head, in_dim] and multiplied in
    place; nothing is returned.

    MLA sub-region scaling per Algorithm 1 (MuonClip):
        wq_b:  q_nope rows → √γ, q_pe rows → γ
        wkv_b: k_nope rows → √γ, v rows → unchanged
    Any other kind is left untouched when ``info.is_mla`` is set.
    """
    weight = p.data if isinstance(p, torch.nn.Parameter) else p

    if not info.is_mla:
        # MHA/GQA: every row of a head gets the same √γ.
        per_head = weight.view(-1, info.head_dim, weight.shape[1])
        per_head.mul_(scales.view(-1, 1, 1))
        return

    if info.kind not in ('wq_b', 'wkv_b'):
        return

    is_query = info.kind == 'wq_b'
    nope = info.qk_nope_head_dim
    # Rows following the nope block: rotary rows for wq_b, value rows for wkv_b.
    trailing = info.qk_rope_head_dim if is_query else info.v_head_dim

    per_head = weight.view(-1, nope + trailing, weight.shape[1])
    root_gamma = scales.view(-1, 1, 1)

    per_head[:, :nope, :].mul_(root_gamma)  # q_nope / k_nope → √γ
    if is_query:
        per_head[:, nope:, :].mul_(root_gamma * root_gamma)  # q_pe → γ
    # wkv_b v rows: not touched (k_R shared rotary unchanged)