# Kernels
# dongseokmotif's picture
# feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build] (#28)
# e8e2c81 unverified
import logging
import math
from dataclasses import dataclass
import torch
from torch.distributed.tensor import DTensor
from .core import normalize_fqn
logger = logging.getLogger(__name__)
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """Classify a parameter name as a query/key projection and locate its layer.

    The name is passed through ``normalize_fqn`` and split on ``'.'``. The
    second-to-last component is treated as the projection kind; the layer
    index is the right-most component made up entirely of digits.

    Recognized kinds:
        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
        MLA:     'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

    Args:
        name: Fully qualified parameter name.

    Returns:
        ``(kind, layer_idx)`` when the kind is recognized (``layer_idx`` is
        -1 if the name has no all-digit component), otherwise ``(None, -1)``.
        Names with fewer than three dot-separated components always yield
        ``(None, -1)``.

    Example:
        'model.3.attn.wq.weight'     -> ('wq', 3)
        'model.5.attn.wk.weight'     -> ('wk', 5)
        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
        'model.1.attn.wq_b.weight'   -> ('wq_b', 1)
        'model.0.attn.wkv_b.weight'  -> ('wkv_b', 0)
        'model.4.attn.v_proj.weight' -> (None, -1)
    """
    components = normalize_fqn(name).split('.')
    if len(components) < 3:
        return None, -1

    # Resolve the layer index first (right-most numeric component wins),
    # then decide whether the projection kind is one we handle.
    layer_idx = next(
        (int(piece) for piece in reversed(components) if piece.isdigit()),
        -1,
    )
    candidate = components[-2]
    if candidate not in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return None, -1
    return candidate, layer_idx
@dataclass
class QKClipInfo:
"""Per-parameter dynamic info computed from config + runtime logits."""
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
indices: list[int] # which heads to consider for clipping
head_dim: int # from config (qk_head_dim for MLA wq_b)
threshold: float # from config
logit: torch.Tensor | None
# MLA-specific fields
is_mla: bool = False
qk_nope_head_dim: int = 0
qk_rope_head_dim: int = 0
v_head_dim: int = 0
def get_qk_clip_info(clip_config, n, qk_logits):
    """Build the QK clipping info for one named parameter.

    Args:
        clip_config: QK clipping configuration dict (or None).
            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
            MLA extra keys: is_mla=True, qk_nope_head_dim,
            qk_rope_head_dim, v_head_dim
        n: Parameter name string.
        qk_logits: Mapping from layer index to logit tensor (or None).

    Returns:
        None when ``clip_config`` is None; otherwise a ``QKClipInfo`` for
        this parameter. ``logit`` and ``indices`` are only filled in when
        ``qk_logits`` is given and the name parses as a Q/K projection.

    Raises:
        KeyError: If ``is_mla`` is truthy and one of the MLA dimension keys
            is missing from ``clip_config``.
    """
    if clip_config is None:
        return None

    head_dim = clip_config.get('head_dim')
    threshold = clip_config.get('threshold')
    kind, layer_idx = parse_qk_layer(n)
    mla_enabled = clip_config.get('is_mla', False)

    logit = None
    indices = []
    if kind is not None and qk_logits is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            # In TP settings the logits may arrive as a DTensor; gather the
            # full tensor here for simplicity.
            logit = logit.full_tensor()

        if kind in ('wq_b', 'wq', 'q_proj'):
            indices = clip_config.get('q_indices', []) or []
        elif kind in ('wkv_b', 'wk', 'k_proj'):
            indices = clip_config.get('k_indices', []) or []

    fields = dict(
        kind=kind,
        indices=indices,
        head_dim=head_dim,
        threshold=threshold,
        logit=logit,
    )
    if mla_enabled:
        # MLA needs the sub-head dimensions; these keys are mandatory.
        fields.update(
            is_mla=True,
            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
            v_head_dim=clip_config['v_head_dim'],
        )
    return QKClipInfo(**fields)
def compute_scales(p, qk_clip_state):
    """Compute per-head scaling factors for QK clipping.

    For every ``(logit_idx, head_idx)`` pair from ``enumerate(indices)``,
    the logit at ``logit_idx`` is compared to the threshold. A head whose
    logit exceeds it gets the scale ``sqrt(threshold / logit)``; if a head
    appears more than once, the smallest scale is kept.

    The number of heads is ``p.shape[0]`` divided by the rows each head
    occupies in ``p``:
        MLA wkv_b: qk_nope_head_dim + v_head_dim
        MLA wq_b:  qk_nope_head_dim + qk_rope_head_dim (the same stride
                   ``qk_clip`` reshapes by), falling back to ``head_dim``
                   when both MLA dims are 0
        otherwise: head_dim

    Args:
        p: Weight whose first dimension is heads * rows-per-head.
        qk_clip_state: Object exposing ``kind``, ``indices``, ``head_dim``,
            ``threshold``, ``logit``, ``is_mla`` and the MLA dimension
            fields (e.g. a ``QKClipInfo``).

    Returns:
        None if no head exceeds the threshold; otherwise a 1-D tensor on
        ``p.data.device`` holding 1.0 for untouched heads and the scale
        (√γ) for clipped heads.
    """
    kind = qk_clip_state.kind
    head_dim = qk_clip_state.head_dim
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    # Collect offending heads first so nothing is allocated when no head
    # needs clipping.
    head_scales = {}
    for logit_idx, head_idx in enumerate(qk_clip_state.indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale
                logger.info(
                    f"[{kind}] Head {head_idx} exceeded threshold "
                    f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
                )
    if not head_scales:
        return None

    # Rows per head must match the reshape used when the scales are applied,
    # otherwise the scales vector has the wrong length for the weight view.
    if qk_clip_state.is_mla and kind == 'wkv_b':
        rows_per_head = (qk_clip_state.qk_nope_head_dim
                         + qk_clip_state.v_head_dim)
    elif qk_clip_state.is_mla and kind == 'wq_b':
        rows_per_head = (qk_clip_state.qk_nope_head_dim
                         + qk_clip_state.qk_rope_head_dim) or head_dim
    else:
        rows_per_head = head_dim

    H_global = p.shape[0] // rows_per_head
    scales_full = torch.ones(H_global, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale
    return scales_full
def qk_clip(p, scales, info):
    """Scale a Q/K projection weight in place, one factor per head.

    Args:
        p: Parameter (nn.Parameter or raw tensor); modified in place.
        scales: [n_heads] tensor, each element = √γ_h.
        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.

    Non-MLA weights get √γ on every row of each head. With ``info.is_mla``
    set, the sub-regions follow Algorithm 1 (MuonClip):
        wq_b:  q_nope rows → √γ, q_pe rows → γ
        wkv_b: k_nope rows → √γ, v rows → unchanged
    Any other kind in MLA mode leaves the weight untouched.
    """
    weight = p.data if isinstance(p, torch.nn.Parameter) else p

    def per_head_view(rows_per_head):
        # [H, rows_per_head, in_dim] view over the same storage.
        return weight.view(-1, rows_per_head, weight.shape[1])

    def broadcastable(factors):
        # [H] -> [H, 1, 1] so each factor covers its whole head slice.
        return factors.view(-1, 1, 1)

    if not info.is_mla:
        # MHA/GQA: one uniform √γ for all rows of a head.
        per_head_view(info.head_dim).mul_(broadcastable(scales))
        return

    if info.kind == 'wq_b':
        nope_rows = info.qk_nope_head_dim
        heads = per_head_view(nope_rows + info.qk_rope_head_dim)
        heads[:, :nope_rows, :].mul_(broadcastable(scales))  # q_nope → √γ
        heads[:, nope_rows:, :].mul_(broadcastable(scales * scales))  # q_pe → γ
    elif info.kind == 'wkv_b':
        nope_rows = info.qk_nope_head_dim
        heads = per_head_view(nope_rows + info.v_head_dim)
        heads[:, :nope_rows, :].mul_(broadcastable(scales))  # k_nope → √γ
        # v rows are deliberately left as they are.