import logging
import math
from dataclasses import dataclass

import torch
from torch.distributed.tensor import DTensor

from .core import normalize_fqn

logger = logging.getLogger(__name__)

# Projection names treated as query-side and key-side QK-clip targets.
# Kept in one place so parse_qk_layer and get_qk_clip_info always agree.
_Q_KINDS = ('wq_b', 'wq', 'q_proj')
_K_KINDS = ('wkv_b', 'wk', 'k_proj')
_QK_KINDS = _Q_KINDS + _K_KINDS


def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """
    Parse a parameter name to check if it is a query/key projection layer
    and return (kind, layer_index).

    The name is first passed through ``normalize_fqn`` and split on '.'.
    The kind is the second-to-last component; the layer index is the last
    all-digit component found when scanning from the end.

    Supported kinds:
        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
        MLA:     'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

    Returns:
        (kind, layer_idx) or (None, -1) if not matched. A matched kind
        whose name contains no numeric component yields (kind, -1).

    Example:
        'model.3.attn.wq.weight' -> ('wq', 3)
        'model.5.attn.wk.weight' -> ('wk', 5)
        'model.2.attn.q_proj.weight' -> ('q_proj', 2)
        'model.7.attn.k_proj.weight' -> ('k_proj', 7)
        'model.1.attn.wq_b.weight' -> ('wq_b', 1)
        'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
        'model.4.attn.v_proj.weight' -> (None, -1)
    """
    parts = normalize_fqn(name).split('.')
    if len(parts) < 3:
        return None, -1

    kind = parts[-2]
    if kind not in _QK_KINDS:
        return None, -1

    # Innermost numeric path component is taken as the layer index.
    layer_idx = next(
        (int(part) for part in reversed(parts) if part.isdigit()),
        -1,
    )
    return kind, layer_idx


@dataclass
class QKClipInfo:
    """Per-parameter dynamic info computed from config + runtime logits."""
    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
    indices: list[int]  # which heads to consider for clipping
    head_dim: int  # from config (qk_head_dim for MLA wq_b)
    threshold: float  # from config
    logit: torch.Tensor | None  # per-head logits for this layer, if available

    # MLA-specific fields
    is_mla: bool = False
    qk_nope_head_dim: int = 0
    qk_rope_head_dim: int = 0
    v_head_dim: int = 0


def get_qk_clip_info(clip_config, n, qk_logits) -> QKClipInfo | None:
    """Extract QK clipping info for a named parameter.

    Args:
        clip_config: QK clipping configuration dict (or None).
            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
            MLA extra keys: is_mla=True, qk_nope_head_dim,
                qk_rope_head_dim, v_head_dim
        n: Parameter name string.
        qk_logits: Dict mapping layer indices to logit tensors (or None).

    Returns:
        None when ``clip_config`` is None; otherwise a QKClipInfo instance
        with the clipping configuration for this parameter. When
        ``qk_logits`` is None or the parameter is not a Q/K projection,
        the returned info has ``logit=None`` and ``indices=[]``.

    Raises:
        KeyError: if ``is_mla`` is set but one of the MLA dimension keys
            is missing from ``clip_config``. A lookup failure on
            ``qk_logits[layer_idx]`` also propagates to the caller.
    """
    if clip_config is None:
        return None

    head_dim = clip_config.get('head_dim')
    threshold = clip_config.get('threshold')
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)

    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            # In TP settings, qk_logits may be DTensor.
            # We convert it to a full tensor here for simplicity.
            logit = logit.full_tensor()

        if kind in _Q_KINDS:
            indices = clip_config.get('q_indices', []) or []
        elif kind in _K_KINDS:
            indices = clip_config.get('k_indices', []) or []

    # MLA sub-head dimensions are mandatory config keys only for MLA;
    # otherwise the dataclass defaults apply.
    mla_fields = {}
    if is_mla:
        mla_fields = {
            'is_mla': True,
            'qk_nope_head_dim': clip_config['qk_nope_head_dim'],
            'qk_rope_head_dim': clip_config['qk_rope_head_dim'],
            'v_head_dim': clip_config['v_head_dim'],
        }

    return QKClipInfo(
        kind=kind,
        indices=indices,
        head_dim=head_dim,
        threshold=threshold,
        logit=logit,
        **mla_fields,
    )


def compute_scales(p, qk_clip_state):
    """Compute per-head scaling factors for QK clipping.

    For each position ``i`` in ``qk_clip_state.indices``, ``logit[i]`` is
    compared against the threshold; a head that exceeds it gets the scale
    sqrt(threshold / logit). If a head index appears more than once, the
    smallest scale wins.

    Args:
        p: Weight whose first dimension is n_heads * effective_head_dim.
        qk_clip_state: QKClipInfo for this parameter.

    Returns:
        Scales tensor (√γ per head, 1.0 for unclipped heads) on the device
        of ``p`` if any head exceeds the threshold, else None. None is also
        returned when there are no logits or no head indices to check.

    For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
    """
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    head_dim = qk_clip_state.head_dim
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    # Nothing to clip without logits or target heads.
    if logit is None or not indices:
        return None

    # Check if any head exceeds threshold before allocating.
    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale
                logger.info(
                    "[%s] Head %s exceeded threshold "
                    "(value=%.4f, threshold=%.4f) -> applying scale=%.4f",
                    kind,
                    head_idx,
                    v_ele,
                    threshold,
                    new_scale,
                )

    if not head_scales:
        return None

    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
    if qk_clip_state.is_mla and kind == 'wkv_b':
        effective_head_dim = (
            qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
        )
    else:
        effective_head_dim = head_dim

    H_global = p.shape[0] // effective_head_dim
    scales_full = torch.ones(H_global, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale
    return scales_full


def qk_clip(p, scales, info):
    """Apply per-head scaling to a Q/K projection weight matrix, in place.

    Args:
        p: Parameter (nn.Parameter or raw tensor).
        scales: [n_heads] tensor, each element = √γ_h.
        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.

    MLA sub-region scaling per Algorithm 1 (MuonClip):
        wq_b:  q_nope rows → √γ, q_pe rows → γ
        wkv_b: k_nope rows → √γ, v rows → unchanged
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p

    if not info.is_mla:
        # MHA/GQA: uniform √γ applied to all rows in each head
        W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
        return

    # MLA: vectorized sub-region scaling within each head
    if info.kind == 'wq_b':
        qk_nope = info.qk_nope_head_dim
        qk_head_dim = qk_nope + info.qk_rope_head_dim
        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
        W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
    elif info.kind == 'wkv_b':
        qk_nope = info.qk_nope_head_dim
        kv_stride = qk_nope + info.v_head_dim
        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
        # v rows: not touched (k_R shared rotary unchanged)
    # NOTE(review): with is_mla set, any other kind (e.g. 'wq', 'q_proj')
    # falls through here and the weight is left unscaled — confirm that is
    # intended for MLA models that use a plain Q projection.