Kernels
File size: 10,351 Bytes
64a7ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# SPDX-License-Identifier: Apache-2.0
# MegaBlocks C++ Optimized CPU MoE

"""
C++ accelerated MoE with brgemm optimization for Intel AMX.
Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
"""

import torch
from typing import Optional
from .cpu_fused_moe import route_tokens_cpu
from ._ops import ops


def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
    if tensor is None:
        return None
    # Check if it's a DTensor by looking for the to_local() method
    if hasattr(tensor, "to_local"):
        return tensor.to_local()
    return tensor


def fused_moe_cpp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_int8_w8a8: bool = False,
    use_fp8_w8a16: bool = False,
    use_mxfp4: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[list] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    limit: Optional[float] = None,
    is_vnni: bool = False,
) -> torch.Tensor:
    """
    Run the fused mixture-of-experts C++ op (sglang compatible interface).

    This is a thin Python front end around ``ops.fused_experts``: it
    normalises dtypes, unwraps DTensor-like arguments, forwards everything to
    the op positionally, and casts the result back to the caller's dtype.
    The op itself is described upstream as using
    at::native::cpublas::brgemm and supporting both silu_and_mul (standard
    SwiGLU) and swigluoai (GptOss) activations.

    Args:
        hidden_states: Input tensor [M, K]
        w1: Gate and up projections [E, 2N, K]
        w2: Down projection [E, K, N]
        topk_weights: Expert weights [M, topk]
        topk_ids: Expert indices [M, topk]
        inplace: Whether to use hidden_states as output
        use_int8_w8a8: Use int8 quantization
        use_fp8_w8a16: Use fp8 quantization
        use_mxfp4: Use mxfp4 quantization
        w1_scale, w2_scale: Quantization scales
        block_size: Block size for fp8
        a1_scale, a2_scale: Activation scales
        w1_bias, w2_bias: Optional biases; cast to the dtype the op runs in
        alpha: swigluoai alpha parameter (set to enable swiglu)
        limit: swigluoai limit parameter (set to enable swiglu)
        is_vnni: Whether w1/w2 are already in VNNI packed format

    Returns:
        The op's output, cast back to the dtype ``hidden_states`` had on
        entry whenever the input had to be converted to bfloat16.
    """
    input_dtype = hidden_states.dtype
    if input_dtype == torch.float32:
        # float32 activations are always down-cast before entering the op.
        cast_to_bf16 = True
    else:
        # The MXFP4 / FP8 paths only accept bfloat16 activations.
        cast_to_bf16 = (use_mxfp4 or use_fp8_w8a16) and input_dtype != torch.bfloat16
    if cast_to_bf16:
        hidden_states = hidden_states.to(torch.bfloat16)

    # Biases have to share the dtype the activations now carry.
    compute_dtype = hidden_states.dtype
    if w1_bias is not None:
        w1_bias = w1_bias.to(compute_dtype)
    if w2_bias is not None:
        w2_bias = w2_bias.to(compute_dtype)

    # Custom ops need plain tensors, so strip DTensor wrappers (TP mode).
    (
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
        w1_bias,
        w2_bias,
    ) = [
        _to_local_tensor(candidate)
        for candidate in (
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
            w1_bias,
            w2_bias,
        )
    ]

    result = ops.fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace,
        use_int8_w8a8,
        use_fp8_w8a16,
        use_mxfp4,
        w1_scale,
        w2_scale,
        block_size,
        a1_scale,
        a2_scale,
        w1_bias,
        w2_bias,
        alpha,
        limit,
        is_vnni,
    )

    # Hand the caller back the dtype it passed in.
    return result.to(input_dtype) if cast_to_bf16 else result


class CPUMegaBlocksMoeMLP(torch.nn.Module):
    """
    Mixture-of-experts MLP whose expert computation runs in the C++ kernel.

    Meant as a drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP. The
    class only supplies ``forward``; it reads ``self.router`` and
    ``self.experts`` from the module it runs on and records its one-time
    packing state in the ``packed_scales``, ``packed_weight`` and
    ``use_mxfp4`` attributes of that module.

    Usage in transformers:
        # Will be used via @use_kernel_forward_from_hub decorator
    """
    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Route tokens to experts and evaluate them with the C++ kernel.

        On the first call the expert scales and weights are repacked in
        place (and the biases cast) so later calls can skip that work.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_size]

        Returns:
            Tuple of (output, expert_weights); output has the shape of ``x``.

        Raises:
            AttributeError: if ``self.experts`` exposes neither
                'gate_up_proj'/'w1' nor 'down_proj'/'w2'.
        """
        # The flag may never have been set on this module; default it to off.
        if getattr(self, "use_mxfp4", None) is None:
            self.use_mxfp4 = False

        experts = self.experts

        # --- one-time packing of quantization scales ----------------------
        # A precision config on the experts switches the MXFP4 path on.
        scales_pending = not getattr(self, "packed_scales", False)
        if (
            scales_pending
            and hasattr(experts, "gate_up_proj")
            and getattr(experts, "gate_up_proj_precision_config", None) is not None
        ):
            gate_up_scales = ops.convert_scale_packed(
                experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            down_scales = ops.convert_scale_packed(
                experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            # Both conversions finish before either result is stored.
            experts.gate_up_proj_precision_config.weight_scale.storage.data = gate_up_scales
            experts.down_proj_precision_config.weight_scale.storage.data = down_scales
            self.packed_scales = True
            self.use_mxfp4 = True

        # --- one-time packing of expert weights ---------------------------
        if not getattr(self, "packed_weight", False) and hasattr(experts, "gate_up_proj"):
            gate_up_t = experts.gate_up_proj.data.transpose(-1, -2).contiguous()
            down_t = experts.down_proj.data.transpose(-1, -2).contiguous()
            if self.use_mxfp4:
                # MXFP4 weights are written back through their storage wrapper.
                experts.gate_up_proj.storage.data = ops.convert_weight_packed(gate_up_t)
                experts.down_proj.storage.data = ops.convert_weight_packed(down_t)
            else:
                # convert_weight_packed only supports bfloat16, float16, int8,
                # fp8_e4m3 or uint8(mxfp4 or int4), so float32 is down-cast.
                if gate_up_t.dtype == torch.float32:
                    gate_up_t = gate_up_t.to(torch.bfloat16)
                if down_t.dtype == torch.float32:
                    down_t = down_t.to(torch.bfloat16)
                experts.gate_up_proj.data = ops.convert_weight_packed(gate_up_t)
                experts.down_proj.data = ops.convert_weight_packed(down_t)

            # C++ kernel does not support float32, so biases follow bfloat16
            # whenever the activations arrive as float32.
            bias_dtype = x.dtype
            if bias_dtype == torch.float32:
                bias_dtype = torch.bfloat16
            for bias_name in ("gate_up_proj_bias", "down_proj_bias"):
                bias_param = getattr(experts, bias_name, None)
                if bias_param is not None:
                    bias_param.data = bias_param.data.to(bias_dtype)

            self.packed_weight = True

        # --- routing hyper-parameters (with fallbacks) --------------------
        router = self.router
        top_k = getattr(router, "top_k", 4)
        num_experts = getattr(experts, "num_experts", 128)
        normalize_weights = getattr(experts, "normalize_expert_weights", None)

        # --- activation: alpha/limit are forwarded only for swigluoai -----
        if hasattr(experts, "alpha") and hasattr(experts, "limit"):
            use_alpha, use_limit = experts.alpha, experts.limit
        elif getattr(experts, "activation", "silu") == "swigluoai":
            # swigluoai requested by name without explicit parameters.
            use_alpha, use_limit = 1.702, 7.0
        else:
            use_alpha = use_limit = None

        # --- locate the projection weights --------------------------------
        if hasattr(experts, "gate_up_proj"):
            gate_up_weight = experts.gate_up_proj
        elif hasattr(experts, "w1"):
            gate_up_weight = experts.w1
            up_weight = getattr(experts, "w3", None)
            if up_weight is not None:
                gate_up_weight = torch.cat([gate_up_weight, up_weight], dim=-1)
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")

        if hasattr(experts, "down_proj"):
            down_weight = experts.down_proj
        elif hasattr(experts, "w2"):
            down_weight = experts.w2
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # --- optional biases, passed as raw tensors -----------------------
        gate_up_bias = getattr(experts, "gate_up_proj_bias", None)
        if gate_up_bias is not None:
            gate_up_bias = gate_up_bias.data
        down_bias = getattr(experts, "down_proj_bias", None)
        if down_bias is not None:
            down_bias = down_bias.data

        # --- scales are only supplied on the MXFP4 path -------------------
        if self.use_mxfp4:
            gate_up_scale = experts.gate_up_proj_precision_config.weight_scale.data
            down_scale = experts.down_proj_precision_config.weight_scale.data
        else:
            gate_up_scale = None
            down_scale = None

        original_shape = x.size()

        # Routing stays in Python; the logits it returns are not used here.
        _logits, expert_weights, expert_indices = route_tokens_cpu(
            x,
            router.weight,
            getattr(router, "bias", None),
            top_k,
            num_experts,
            normalize_weights,
        )

        # Collapse leading dimensions so the kernel sees a 2-D token matrix.
        tokens = x.view(-1, x.shape[-1])

        result = fused_moe_cpp(
            hidden_states=tokens,
            w1=gate_up_weight.data,
            w2=down_weight.data,
            topk_weights=expert_weights,
            topk_ids=expert_indices.to(torch.int32),
            inplace=False,
            use_int8_w8a8=False,
            use_fp8_w8a16=False,
            use_mxfp4=self.use_mxfp4,
            w1_scale=gate_up_scale,
            w2_scale=down_scale,
            block_size=None,
            a1_scale=None,
            a2_scale=None,
            w1_bias=gate_up_bias,
            w2_bias=down_bias,
            alpha=use_alpha,
            limit=use_limit,
            is_vnni=getattr(self, "packed_weight", False),
        )

        # Give the output the same shape the input had.
        return result.view(original_shape), expert_weights


# Public API of this module. Every entry must name an object this module
# actually defines, otherwise a star-import raises AttributeError; the MoE
# layer class here is called CPUMegaBlocksMoeMLP.
__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]