Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("Motif-Technologies/optimizer")
- Notebooks
- Google Colab
- Kaggle
File size: 6,801 Bytes
33929c0 10848ab 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 10848ab 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 e8e2c81 33929c0 10848ab 33929c0 10848ab 33929c0 10848ab e8e2c81 10848ab 33929c0 e8e2c81 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | import logging
import math
from dataclasses import dataclass
import torch
from torch.distributed.tensor import DTensor
from .core import normalize_fqn
logger = logging.getLogger(__name__)
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """
    Classify a parameter name as a query/key projection and locate its layer.

    The name is run through ``normalize_fqn`` and split on '.'. The
    second-to-last component is treated as the projection kind, and the
    all-digit component closest to the end of the name is treated as the
    layer index.

    Recognised kinds:
        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
        MLA:     'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

    Returns:
        (kind, layer_idx) when the kind is recognised; layer_idx is -1 if the
        name has no all-digit component. (None, -1) when the name has fewer
        than three components or the kind is not recognised.

    Example:
        'model.3.attn.wq.weight'     -> ('wq', 3)
        'model.5.attn.wk.weight'     -> ('wk', 5)
        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
        'model.1.attn.wq_b.weight'   -> ('wq_b', 1)
        'model.0.attn.wkv_b.weight'  -> ('wkv_b', 0)
        'model.4.attn.v_proj.weight' -> (None, -1)
    """
    components = normalize_fqn(name).split('.')
    if len(components) < 3:
        return None, -1

    # Scan from the end so the innermost numeric component wins; -1 if none.
    layer_idx = next(
        (int(piece) for piece in reversed(components) if piece.isdigit()),
        -1,
    )

    candidate = components[-2]
    recognised = ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b')
    if candidate not in recognised:
        return None, -1
    return candidate, layer_idx
@dataclass
class QKClipInfo:
"""Per-parameter dynamic info computed from config + runtime logits."""
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
indices: list[int] # which heads to consider for clipping
head_dim: int # from config (qk_head_dim for MLA wq_b)
threshold: float # from config
logit: torch.Tensor | None
# MLA-specific fields
is_mla: bool = False
qk_nope_head_dim: int = 0
qk_rope_head_dim: int = 0
v_head_dim: int = 0
def get_qk_clip_info(clip_config, n, qk_logits):
"""Extract QK clipping info for a named parameter.
Args:
clip_config: QK clipping configuration dict (or None).
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
n: Parameter name string.
qk_logits: Dict mapping layer indices to logit tensors (or None).
Returns:
QKClipInfo instance with clipping configuration for this parameter.
"""
if clip_config is None:
return None
head_dim = clip_config.get('head_dim')
threshold = clip_config.get('threshold')
kind, layer_idx = parse_qk_layer(n)
is_mla = clip_config.get('is_mla', False)
logit, indices = None, []
if qk_logits is not None and kind is not None:
logit = qk_logits[layer_idx]
if isinstance(logit, DTensor):
# In TP settings, qk_logits may be DTensor
# We convert it to full tensor here for simplicity
logit = logit.full_tensor()
if kind in ('wq_b', 'wq', 'q_proj'):
indices = clip_config.get('q_indices', []) or []
elif kind in ('wkv_b', 'wk', 'k_proj'):
indices = clip_config.get('k_indices', []) or []
if is_mla:
return QKClipInfo(
kind=kind,
indices=indices,
head_dim=head_dim,
threshold=threshold,
logit=logit,
is_mla=True,
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
v_head_dim=clip_config['v_head_dim'],
)
else:
return QKClipInfo(
kind=kind,
indices=indices,
head_dim=head_dim,
threshold=threshold,
logit=logit,
)
def compute_scales(p, qk_clip_state):
    """Derive the per-head scaling factors used for QK clipping.

    For every listed head whose logit is above the threshold, the factor is
    sqrt(threshold / logit) (i.e. √γ). If a head is listed more than once,
    the smallest factor is kept.

    Args:
        p: The projection weight; its first dimension is divided by the
            rows-per-head to get the number of heads.
        qk_clip_state: QKClipInfo-like object. ``logit[i]`` is the logit for
            head ``indices[i]``.

    Returns:
        A tensor of per-head factors on p's device (1 for heads that are not
        clipped), or None if no head exceeds the threshold.

    For MLA wkv_b, the row stride per head is qk_nope_head_dim + v_head_dim.
    """
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    head_dim = qk_clip_state.head_dim
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    # Collect offending heads first so nothing is allocated in the common
    # "no clipping needed" case.
    clipped = {}
    for position, head_idx in enumerate(indices):
        value = float(logit[position])
        if not value > threshold:
            continue
        candidate = math.sqrt(threshold / value)
        previous = clipped.get(head_idx)
        if previous is None or candidate < previous:
            clipped[head_idx] = candidate
            logger.info(
                f"[{kind}] Head {head_idx} exceeded threshold "
                f"(value={value:.4f}, threshold={threshold:.4f}) -> applying scale={candidate:.4f}"
            )

    if not clipped:
        return None

    # For MLA wkv_b each KV head spans qk_nope_head_dim + v_head_dim rows.
    uses_kv_stride = qk_clip_state.is_mla and kind == 'wkv_b'
    if uses_kv_stride:
        rows_per_head = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
    else:
        rows_per_head = head_dim

    num_heads = p.shape[0] // rows_per_head
    scales_full = torch.ones(num_heads, device=p.data.device)
    for head_idx, factor in clipped.items():
        scales_full[head_idx] = factor
    return scales_full
def qk_clip(p, scales, info):
    """Scale the rows of a Q/K projection weight head by head, in place.

    Args:
        p: Parameter (nn.Parameter or raw tensor). The weight is viewed as
            [heads, rows_per_head, in_dim] and modified in place.
        scales: [n_heads] tensor, each element = √γ_h.
        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.

    Returns:
        None.

    Non-MLA (MHA/GQA): every row of head h is multiplied by √γ_h.

    MLA sub-region scaling per Algorithm 1 (MuonClip):
        wq_b:  q_nope rows → √γ, q_pe rows → γ
        wkv_b: k_nope rows → √γ, v rows → unchanged

    NOTE(review): with is_mla set, any kind other than 'wq_b'/'wkv_b' is
    left untouched - confirm that is intended for MLA models that use a
    plain 'wq' projection.
    """
    weight = p.data if isinstance(p, torch.nn.Parameter) else p

    if not info.is_mla:
        # MHA/GQA: one uniform √γ factor across all rows of each head.
        per_head = weight.view(-1, info.head_dim, weight.shape[1])
        per_head.mul_(scales.view(-1, 1, 1))
        return

    kind = info.kind
    if kind == 'wq_b':
        nope_rows = info.qk_nope_head_dim
        rows_per_head = nope_rows + info.qk_rope_head_dim
        per_head = weight.view(-1, rows_per_head, weight.shape[1])  # [H, qk_head_dim, in_dim]
        nope_part = per_head[:, :nope_rows, :]
        rope_part = per_head[:, nope_rows:, :]
        nope_part.mul_(scales.view(-1, 1, 1))  # q_nope → √γ
        rope_part.mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
    elif kind == 'wkv_b':
        nope_rows = info.qk_nope_head_dim
        rows_per_head = nope_rows + info.v_head_dim
        per_head = weight.view(-1, rows_per_head, weight.shape[1])  # [H, kv_stride, in_dim]
        nope_part = per_head[:, :nope_rows, :]
        nope_part.mul_(scales.view(-1, 1, 1))  # k_nope → √γ
        # v rows are deliberately left as they are (k_R shared rotary unchanged).
|