Kernels
File size: 6,801 Bytes
33929c0
 
 
 
 
 
 
10848ab
 
33929c0
 
 
 
 
 
e8e2c81
 
 
 
 
33929c0
 
 
 
 
 
 
 
 
e8e2c81
 
33929c0
 
10848ab
33929c0
 
 
 
 
 
 
 
 
 
 
e8e2c81
33929c0
 
 
 
 
 
 
 
e8e2c81
33929c0
e8e2c81
33929c0
 
 
e8e2c81
 
 
 
 
 
33929c0
 
 
 
 
 
e8e2c81
 
33929c0
 
 
 
 
 
 
 
 
 
 
 
e8e2c81
33929c0
 
 
 
 
 
 
 
 
e8e2c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33929c0
 
 
 
 
e8e2c81
 
33929c0
 
 
 
 
 
 
10848ab
 
33929c0
 
 
 
10848ab
 
33929c0
 
 
 
 
10848ab
 
 
e8e2c81
 
 
 
 
 
 
10848ab
 
 
 
33929c0
 
e8e2c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import logging
import math
from dataclasses import dataclass

import torch
from torch.distributed.tensor import DTensor

from .core import normalize_fqn

# Module-level logger; compute_scales reports each clipped head through it.
logger = logging.getLogger(__name__)


def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """
    Classify a parameter name as a query/key projection and locate its layer.

    The name is first normalized with ``normalize_fqn`` and split on '.'.
    The projection kind is the second-to-last component; the layer index is
    the last component (scanning from the end) made only of digits.

    Supported kinds:
        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
        MLA:     'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

    Returns:
        (kind, layer_idx) or (None, -1) if not matched. Names with fewer than
        three components never match. If the kind matches but no numeric
        component exists, layer_idx is -1.

    Example:
        'model.3.attn.wq.weight'      -> ('wq', 3)
        'model.5.attn.wk.weight'      -> ('wk', 5)
        'model.2.attn.q_proj.weight'  -> ('q_proj', 2)
        'model.7.attn.k_proj.weight'  -> ('k_proj', 7)
        'model.1.attn.wq_b.weight'    -> ('wq_b', 1)
        'model.0.attn.wkv_b.weight'   -> ('wkv_b', 0)
        'model.4.attn.v_proj.weight'  -> (None, -1)
    """
    segments = normalize_fqn(name).split('.')
    if len(segments) < 3:
        return None, -1

    # Nearest all-digit component counted from the end of the path.
    layer_idx = next(
        (int(seg) for seg in reversed(segments) if seg.isdigit()),
        -1,
    )

    projection = segments[-2]
    if projection not in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return None, -1
    return projection, layer_idx


@dataclass
class QKClipInfo:
    """Per-parameter dynamic info computed from config + runtime logits.

    Instances are built by ``get_qk_clip_info``; ``compute_scales`` and
    ``qk_clip`` read the fields below.
    """
    # Projection kind as returned by parse_qk_layer: query side
    # 'wq'/'q_proj'/'wq_b', key side 'wk'/'k_proj'/'wkv_b', or None when the
    # parameter is not a Q/K projection.
    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
    # Heads to consider for clipping; compute_scales pairs indices[i] with logit[i].
    indices: list[int]  # which heads to consider for clipping
    # Weight rows per head (qk_head_dim for MLA wq_b). Read from the config
    # with .get(), so it may be None if the config omits it.
    head_dim: int  # from config (qk_head_dim for MLA wq_b)
    # Logits strictly above this value trigger scaling of that head.
    threshold: float  # from config
    # Logits for this parameter's layer (DTensors are gathered to a full
    # tensor by get_qk_clip_info), or None when no logits apply.
    logit: torch.Tensor | None

    # MLA-specific fields (left at their defaults unless the config sets is_mla)
    is_mla: bool = False
    qk_nope_head_dim: int = 0  # non-rotary Q/K rows per head
    qk_rope_head_dim: int = 0  # rotary (q_pe) rows per head in the Q up-projection
    v_head_dim: int = 0  # value rows per head in wkv_b


def get_qk_clip_info(clip_config, n, qk_logits):
    """Extract QK clipping info for a named parameter.

    Args:
        clip_config: QK clipping configuration dict (or None).
            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
        n: Parameter name string.
        qk_logits: Per-layer logits indexable by layer index (e.g. a dict
            keyed by layer index), or None.

    Returns:
        None if ``clip_config`` is None; otherwise a QKClipInfo for this
        parameter. Its ``logit`` is None and ``indices`` is empty when
        ``qk_logits`` is None, the name is not a Q/K projection, or the name
        has no numeric layer component.

    Raises:
        KeyError: if ``is_mla`` is set but an MLA dimension key is missing.
            Lookup errors from ``qk_logits[layer_idx]`` also propagate.
    """
    if clip_config is None:
        return None

    head_dim = clip_config.get('head_dim')
    threshold = clip_config.get('threshold')
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)

    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        if layer_idx < 0:
            # The name matched a Q/K projection but carries no layer number.
            # Indexing with -1 would silently pick the last layer of a list
            # (or raise KeyError on a dict), so skip clipping this parameter.
            logger.warning(
                "No layer index found in QK parameter name %r; "
                "skipping QK clipping for it.",
                n,
            )
        else:
            logit = qk_logits[layer_idx]
            if isinstance(logit, DTensor):
                # In TP settings, qk_logits may be DTensor
                # We convert it to full tensor here for simplicity
                logit = logit.full_tensor()

            if kind in ('wq_b', 'wq', 'q_proj'):
                indices = clip_config.get('q_indices', []) or []
            elif kind in ('wkv_b', 'wk', 'k_proj'):
                indices = clip_config.get('k_indices', []) or []

    # MLA sub-head dimensions are required (KeyError if missing) only for MLA.
    mla_fields = {}
    if is_mla:
        mla_fields = {
            'is_mla': True,
            'qk_nope_head_dim': clip_config['qk_nope_head_dim'],
            'qk_rope_head_dim': clip_config['qk_rope_head_dim'],
            'v_head_dim': clip_config['v_head_dim'],
        }

    return QKClipInfo(
        kind=kind,
        indices=indices,
        head_dim=head_dim,
        threshold=threshold,
        logit=logit,
        **mla_fields,
    )


def compute_scales(p, qk_clip_state):
    """Compute per-head scaling factors for QK clipping.

    Entry ``i`` of ``qk_clip_state.indices`` names the head whose logit is
    ``qk_clip_state.logit[i]``. Each head whose logit is strictly greater
    than ``threshold`` gets the factor √γ = sqrt(threshold / logit). If a
    head is listed more than once, the smallest factor wins.

    Args:
        p: The Q/K projection weight. Only ``p.shape[0]`` and
            ``p.data.device`` are read.
        qk_clip_state: QKClipInfo describing ``p``.

    Returns:
        None when no listed head exceeds the threshold. Otherwise, a tensor
        of length ``p.shape[0] // rows_per_head`` on ``p``'s device. It holds
        1.0 for untouched heads and √γ for clipped ones.

    ``rows_per_head`` is the number of weight rows per head:
        MLA wkv_b:              qk_nope_head_dim + v_head_dim
        MLA wq_b / wq / q_proj: qk_nope_head_dim + qk_rope_head_dim
        otherwise:              head_dim
    The MLA strides match the views used by ``qk_clip``, so the number of
    scales equals the number of heads that function reshapes the weight into.
    """
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    if not indices:
        # Nothing to inspect (logit may legitimately be None here).
        return None

    if isinstance(logit, torch.Tensor) and logit.dim() == 1:
        # One host transfer instead of a device sync per float() call.
        logit = logit.tolist()

    # Check if any head exceeds threshold before allocating.
    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale
                logger.info(
                    "[%s] Head %s exceeded threshold "
                    "(value=%.4f, threshold=%.4f) -> applying scale=%.4f",
                    kind, head_idx, v_ele, threshold, new_scale,
                )

    if not head_scales:
        return None

    if qk_clip_state.is_mla and kind == 'wkv_b':
        # Each KV head spans k_nope rows followed by v rows.
        rows_per_head = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
    elif qk_clip_state.is_mla and kind in ('wq_b', 'wq', 'q_proj'):
        # Each Q head spans q_nope rows followed by q_pe rows.
        rows_per_head = (
            qk_clip_state.qk_nope_head_dim + qk_clip_state.qk_rope_head_dim
        )
    else:
        rows_per_head = qk_clip_state.head_dim

    scales_full = torch.ones(p.shape[0] // rows_per_head, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale
    return scales_full


def qk_clip(p, scales, info):
    """Apply per-head scaling to a Q/K projection weight matrix in place.

    The weight's rows are grouped into consecutive per-head blocks. Each
    block (or a sub-region of it, for MLA) is multiplied by that head's
    entry in ``scales``.

    Args:
        p: Parameter (nn.Parameter or raw tensor). Modified in place. It must
            be 2-D and viewable as ``[n_heads, rows_per_head, in_dim]``.
        scales: [n_heads] tensor, each element = √γ_h.
        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.

    MLA sub-region scaling per Algorithm 1 (MuonClip):
        wq_b (also wq / q_proj, i.e. a Q projection without the low-rank path):
            q_nope rows → √γ,  q_pe rows → γ
        wkv_b: k_nope rows → √γ, v rows → unchanged

    Raises:
        ValueError: if ``info.is_mla`` is set and ``info.kind`` is neither an
            MLA query projection nor ``'wkv_b'``.
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p
    # [H] -> [H, 1, 1] so one factor broadcasts over a head's rows and columns.
    per_head = scales.view(-1, 1, 1)

    if not info.is_mla:
        # MHA/GQA: uniform √γ applied to all rows in each head.
        # ``view`` (never ``reshape``) so that mul_ writes through to W.
        W.view(-1, info.head_dim, W.shape[1]).mul_(per_head)
        return

    qk_nope = info.qk_nope_head_dim
    if info.kind in ('wq_b', 'wq', 'q_proj'):
        qk_head_dim = qk_nope + info.qk_rope_head_dim
        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
        W_3d[:, :qk_nope, :].mul_(per_head)  # q_nope → √γ
        W_3d[:, qk_nope:, :].mul_(per_head * per_head)  # q_pe   → γ
    elif info.kind == 'wkv_b':
        kv_stride = qk_nope + info.v_head_dim
        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
        W_3d[:, :qk_nope, :].mul_(per_head)  # k_nope → √γ
        # v rows: not touched (k_R shared rotary unchanged)
    else:
        # Previously this fell through silently, leaving the weight unclipped.
        raise ValueError(
            f"qk_clip: unsupported MLA projection kind {info.kind!r}; "
            "expected 'wq_b', 'wq', 'q_proj' or 'wkv_b'"
        )