Kernels
danieldk's picture
danieldk HF Staff
Build uploaded using `kernels`.
64a7ea9 verified
# SPDX-License-Identifier: Apache-2.0
# MegaBlocks C++ Optimized CPU MoE
"""
C++ accelerated MoE with brgemm optimization for Intel AMX.
Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
"""
import torch
from typing import Optional
from .cpu_fused_moe import route_tokens_cpu
from ._ops import ops
def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
if tensor is None:
return None
# Check if it's a DTensor by looking for the to_local() method
if hasattr(tensor, "to_local"):
return tensor.to_local()
return tensor
def fused_moe_cpp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_int8_w8a8: bool = False,
    use_fp8_w8a16: bool = False,
    use_mxfp4: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[list] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    limit: Optional[float] = None,
    is_vnni: bool = False,
) -> torch.Tensor:
    """
    C++ Fused MoE with brgemm optimization (sglang compatible interface).

    Uses at::native::cpublas::brgemm for efficient batch GEMM on Intel CPUs.
    Supports both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.

    float32 inputs (and any non-bf16 input on the MXFP4/FP8 paths) are cast to
    bfloat16 before the kernel call and the result is cast back afterwards.

    Args:
        hidden_states: Input tensor [M, K]
        w1: Gate and up projections [E, 2N, K]
        w2: Down projection [E, K, N]
        topk_weights: Expert weights [M, topk]
        topk_ids: Expert indices [M, topk]
        inplace: Whether to use hidden_states as output
        use_int8_w8a8: Use int8 quantization
        use_fp8_w8a16: Use fp8 quantization
        use_mxfp4: Use mxfp4 quantization
        w1_scale, w2_scale: Quantization scales
        block_size: Block size for fp8
        a1_scale, a2_scale: Activation scales
        w1_bias, w2_bias: Optional biases
        alpha: swigluoai alpha parameter (set to enable swiglu)
        limit: swigluoai limit parameter (set to enable swiglu)
        is_vnni: Whether w1/w2 are already in VNNI packed format

    Returns:
        The tensor produced by ``ops.fused_experts``, in the dtype of the
        original ``hidden_states``. When ``inplace`` is set and a dtype
        conversion was required, the result is copied into the caller's
        (local) ``hidden_states`` tensor and that tensor is returned.
    """
    # MXFP4/FP8 kernels only support bf16, and float32 is never passed to the
    # kernel, so decide up front whether a round-trip through bf16 is needed.
    orig_dtype = hidden_states.dtype
    need_convert = (
        (use_mxfp4 or use_fp8_w8a16) and orig_dtype != torch.bfloat16
    ) or orig_dtype == torch.float32

    # Remember the caller's tensor: `.to()` below creates a copy, so an
    # in-place kernel call would otherwise only update that temporary.
    caller_hidden_states = hidden_states

    if need_convert:
        hidden_states = hidden_states.to(torch.bfloat16)
        # bias must match hidden_states dtype
        if w1_bias is not None:
            w1_bias = w1_bias.to(hidden_states.dtype)
        if w2_bias is not None:
            w2_bias = w2_bias.to(hidden_states.dtype)

    # Convert DTensor to local tensor for custom ops compatibility (TP mode)
    hidden_states = _to_local_tensor(hidden_states)
    w1 = _to_local_tensor(w1)
    w2 = _to_local_tensor(w2)
    topk_weights = _to_local_tensor(topk_weights)
    topk_ids = _to_local_tensor(topk_ids)
    w1_scale = _to_local_tensor(w1_scale)
    w2_scale = _to_local_tensor(w2_scale)
    a1_scale = _to_local_tensor(a1_scale)
    a2_scale = _to_local_tensor(a2_scale)
    w1_bias = _to_local_tensor(w1_bias)
    w2_bias = _to_local_tensor(w2_bias)

    output = ops.fused_experts(
        hidden_states, w1, w2, topk_weights, topk_ids,
        inplace, use_int8_w8a8, use_fp8_w8a16, use_mxfp4,
        w1_scale, w2_scale, block_size, a1_scale, a2_scale,
        w1_bias, w2_bias, alpha, limit, is_vnni,
    )

    if need_convert:
        if inplace:
            # Honour the in-place contract: write the result back into the
            # tensor the caller handed us (copy_ performs the dtype cast).
            target = _to_local_tensor(caller_hidden_states)
            with torch.no_grad():
                target.copy_(output)
            return target
        # Convert back to original dtype
        output = output.to(orig_dtype)
    return output
class CPUMegaBlocksMoeMLP(torch.nn.Module):
    """
    C++ optimized MoE MLP using brgemm.

    Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.

    The class defines no ``__init__``: ``forward`` reads ``self.router`` and
    ``self.experts`` from the module it runs on, and caches one-time packing
    state on ``self`` (``packed_scales``, ``packed_weight``, ``use_mxfp4``).

    NOTE(review): all logic is kept inside ``forward`` because the layer is
    meant to be applied through the hub decorator; presumably only ``forward``
    is grafted onto the host module — confirm before adding helper methods.

    Usage in transformers:
        # Will be used via @use_kernel_forward_from_hub decorator
    """

    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Forward pass through the MoE layer using C++ kernel.

        On the first call the expert weights (and, for MXFP4 checkpoints, the
        weight scales) are transposed and repacked in place via
        ``ops.convert_weight_packed`` / ``ops.convert_scale_packed``; later
        calls reuse the packed tensors.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_size]

        Returns:
            Tuple of (output, expert_weights); ``output`` has the shape of ``x``.

        Raises:
            AttributeError: if ``self.experts`` exposes neither
                ``gate_up_proj``/``w1`` nor ``down_proj``/``w2``.
        """
        # Optimization for GPT-OSS model
        if getattr(self, "use_mxfp4", None) is None:
            self.use_mxfp4 = False

        experts = self.experts
        w1_scale = None
        w2_scale = None

        # One-time repacking of the quantization scales. The presence of a
        # precision config is what switches the layer to the MXFP4 path.
        if (
            not getattr(self, "packed_scales", False)
            and hasattr(experts, "gate_up_proj")
            and getattr(experts, "gate_up_proj_precision_config", None) is not None
        ):
            packed_gate_up_scale = ops.convert_scale_packed(
                experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            packed_down_scale = ops.convert_scale_packed(
                experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            # NOTE(review): the scales are written through `.storage.data`,
            # presumably a wrapper tensor type rather than a plain
            # torch.Tensor — confirm against the quantized checkpoint loader.
            experts.gate_up_proj_precision_config.weight_scale.storage.data = packed_gate_up_scale
            experts.down_proj_precision_config.weight_scale.storage.data = packed_down_scale
            self.packed_scales = True
            self.use_mxfp4 = True

        # One-time repacking of the expert weights (and bias dtype alignment).
        if not getattr(self, "packed_weight", False) and hasattr(experts, "gate_up_proj"):
            gate_up_t = experts.gate_up_proj.data.transpose(-1, -2).contiguous()
            down_t = experts.down_proj.data.transpose(-1, -2).contiguous()
            if self.use_mxfp4:
                experts.gate_up_proj.storage.data = ops.convert_weight_packed(gate_up_t)
                experts.down_proj.storage.data = ops.convert_weight_packed(down_t)
            else:
                # convert_weight_packed only supports bfloat16, float16, int8,
                # fp8_e4m3 or uint8 (mxfp4 or int4).
                if gate_up_t.dtype == torch.float32:
                    gate_up_t = gate_up_t.to(torch.bfloat16)
                if down_t.dtype == torch.float32:
                    down_t = down_t.to(torch.bfloat16)
                experts.gate_up_proj.data = ops.convert_weight_packed(gate_up_t)
                experts.down_proj.data = ops.convert_weight_packed(down_t)

            # C++ kernel does not support float32.
            bias_dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
            if getattr(experts, "gate_up_proj_bias", None) is not None:
                experts.gate_up_proj_bias.data = experts.gate_up_proj_bias.data.to(bias_dtype)
            if getattr(experts, "down_proj_bias", None) is not None:
                experts.down_proj_bias.data = experts.down_proj_bias.data.to(bias_dtype)
            self.packed_weight = True

        # Get MoE parameters (fallbacks are used when the attribute is absent)
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(experts, "num_experts", 128)
        moe_normalize_expert_weights = getattr(experts, "normalize_expert_weights", None)

        # Detect activation type: experts carrying both `alpha` and `limit`
        # are treated as swigluoai.
        if hasattr(experts, "alpha") and hasattr(experts, "limit"):
            activation = "swigluoai"
            alpha = experts.alpha
            limit = experts.limit
        else:
            activation = getattr(experts, "activation", "silu")
            alpha = 1.702
            limit = 7.0

        # Get weight tensors
        if hasattr(experts, "gate_up_proj"):
            w1 = experts.gate_up_proj
        elif hasattr(experts, "w1"):
            w1 = experts.w1
            w3 = getattr(experts, "w3", None)
            if w3 is not None:
                w1 = torch.cat([w1, w3], dim=-1)
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")

        if hasattr(experts, "down_proj"):
            w2 = experts.down_proj
        elif hasattr(experts, "w2"):
            w2 = experts.w2
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # Get optional bias tensors
        w1_bias = getattr(experts, "gate_up_proj_bias", None)
        w2_bias = getattr(experts, "down_proj_bias", None)
        w1_bias = w1_bias if w1_bias is None else w1_bias.data
        w2_bias = w2_bias if w2_bias is None else w2_bias.data

        if self.use_mxfp4:
            w1_scale = experts.gate_up_proj_precision_config.weight_scale.data
            w2_scale = experts.down_proj_precision_config.weight_scale.data

        # Store original shape
        in_shape = x.size()

        # Route tokens to experts (Python implementation is fast enough).
        # The first returned value (router logits) is not needed here.
        _, expert_weights, expert_indices = route_tokens_cpu(
            x,
            self.router.weight,
            getattr(self.router, "bias", None),
            moe_top_k,
            moe_num_experts,
            moe_normalize_expert_weights,
        )

        # Flatten input. reshape (not view) so non-contiguous activations,
        # e.g. the result of a transpose/slice, do not raise.
        x_flat = x.reshape(-1, x.shape[-1])

        # alpha/limit are only forwarded for the swigluoai activation
        use_alpha = alpha if activation == "swigluoai" else None
        use_limit = limit if activation == "swigluoai" else None

        # Call C++ optimized kernel
        output = fused_moe_cpp(
            hidden_states=x_flat,
            w1=w1.data,
            w2=w2.data,
            topk_weights=expert_weights,
            topk_ids=expert_indices.to(torch.int32),
            inplace=False,
            use_int8_w8a8=False,
            use_fp8_w8a16=False,
            use_mxfp4=self.use_mxfp4,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            block_size=None,
            a1_scale=None,
            a2_scale=None,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            alpha=use_alpha,
            limit=use_limit,
            is_vnni=getattr(self, "packed_weight", False),
        )

        # Restore original shape
        output = output.reshape(in_shape)
        return output, expert_weights
# Public API. Every name listed here must exist in this module, otherwise
# `from <module> import *` raises AttributeError; the layer class defined
# above is `CPUMegaBlocksMoeMLP`.
__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]