Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/megablocks")
- Notebooks
- Google Colab
- Kaggle
File size: 10,351 Bytes
Commit 64a7ea9
# SPDX-License-Identifier: Apache-2.0
# MegaBlocks C++ Optimized CPU MoE
"""
C++ accelerated MoE with brgemm optimization for Intel AMX.
Direct replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.
"""
import torch
from typing import Optional
from .cpu_fused_moe import route_tokens_cpu
from ._ops import ops
def _to_local_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Convert DTensor to local torch.Tensor if needed for custom ops compatibility."""
if tensor is None:
return None
# Check if it's a DTensor by looking for the to_local() method
if hasattr(tensor, "to_local"):
return tensor.to_local()
return tensor
def fused_moe_cpp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_int8_w8a8: bool = False,
    use_fp8_w8a16: bool = False,
    use_mxfp4: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    block_size: Optional[list] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    limit: Optional[float] = None,
    is_vnni: bool = False,
) -> torch.Tensor:
    """
    Run the brgemm-based C++ fused MoE kernel (sglang-compatible interface).

    Thin wrapper around ``ops.fused_experts``. According to its authors, that op
    uses at::native::cpublas::brgemm for batched GEMMs on Intel CPUs. It supports
    both silu_and_mul (standard SwiGLU) and swigluoai (GptOss) activations.
    Before dispatching, this wrapper normalises dtypes and unwraps DTensors.

    Args:
        hidden_states: Input tensor [M, K]. Cast to bf16 when it is fp32, or
            when an MXFP4 / FP8-w8a16 path is selected and it is not already
            bf16.
        w1: Gate and up projections [E, 2N, K]
        w2: Down projection [E, K, N]
        topk_weights: Expert weights [M, topk]
        topk_ids: Expert indices [M, topk]
        inplace: Forwarded to the kernel (use hidden_states as output). When
            the bf16 cast above applies, the kernel only receives a temporary
            copy, so the caller's tensor is not updated in place.
        use_int8_w8a8: Use int8 quantization
        use_fp8_w8a16: Use fp8 quantization
        use_mxfp4: Use mxfp4 quantization
        w1_scale, w2_scale: Quantization scales
        block_size: Block size for fp8
        a1_scale, a2_scale: Activation scales
        w1_bias, w2_bias: Optional biases; cast to the (possibly converted)
            activation dtype.
        alpha: swigluoai alpha parameter (set to enable swiglu)
        limit: swigluoai limit parameter (set to enable swiglu)
        is_vnni: Whether w1/w2 are already in VNNI packed format

    Returns:
        The kernel output. When a bf16 cast was applied to ``hidden_states``,
        the output is cast back to the caller's original dtype.
    """
    input_dtype = hidden_states.dtype

    # The kernel has no fp32 path, and its MXFP4 / FP8 paths only support bf16.
    # In those cases, run on a bf16 copy and restore the dtype afterwards.
    quantized_path = use_mxfp4 or use_fp8_w8a16
    cast_to_bf16 = input_dtype == torch.float32 or (
        quantized_path and input_dtype != torch.bfloat16
    )
    if cast_to_bf16:
        hidden_states = hidden_states.to(torch.bfloat16)

    # Biases must match the dtype of the activations handed to the kernel.
    act_dtype = hidden_states.dtype
    w1_bias = None if w1_bias is None else w1_bias.to(act_dtype)
    w2_bias = None if w2_bias is None else w2_bias.to(act_dtype)

    # Custom ops need plain tensors: unwrap DTensors (tensor-parallel mode)
    # into their local tensors. None entries pass through unchanged.
    (
        hidden_states, w1, w2, topk_weights, topk_ids,
        w1_scale, w2_scale, a1_scale, a2_scale, w1_bias, w2_bias,
    ) = [
        _to_local_tensor(t)
        for t in (
            hidden_states, w1, w2, topk_weights, topk_ids,
            w1_scale, w2_scale, a1_scale, a2_scale, w1_bias, w2_bias,
        )
    ]

    result = ops.fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace,
        use_int8_w8a8,
        use_fp8_w8a16,
        use_mxfp4,
        w1_scale,
        w2_scale,
        block_size,
        a1_scale,
        a2_scale,
        w1_bias,
        w2_bias,
        alpha,
        limit,
        is_vnni,
    )

    return result.to(input_dtype) if cast_to_bf16 else result
class CPUMegaBlocksMoeMLP(torch.nn.Module):
    """
    C++ optimized MoE MLP using brgemm.

    Drop-in replacement for cpu_fused_moe.MegaBlocksMoeMLP with better performance.

    Usage in transformers:
        # Will be used via @use_kernel_forward_from_hub decorator

    The class defines no ``__init__`` and no members besides
    ``can_torch_compile`` and ``forward``. ``forward`` reads ``self.router``
    and ``self.experts`` from the module it runs on. On its first call it
    stores its packing state (``use_mxfp4``, ``packed_scales``,
    ``packed_weight``) on ``self``.
    NOTE(review): hub kernel layers appear to be expected to expose only
    ``forward`` plus flags. Confirm before adding helper methods here.
    """

    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Forward pass through the MoE layer using the C++ kernel.

        On the first call, the expert weights are repacked in place into the
        layout expected by ``ops.convert_weight_packed``. When the experts
        carry a ``gate_up_proj_precision_config``, the MXFP4 weight scales are
        repacked too. Later calls reuse the packed data.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_size]. It does not
                need to be contiguous.

        Returns:
            Tuple of (output, expert_weights). ``output`` has the same shape
            as ``x``. ``expert_weights`` are the routing weights produced by
            ``route_tokens_cpu``.

        Raises:
            AttributeError: if ``self.experts`` exposes neither
                ``gate_up_proj`` nor ``w1``, or neither ``down_proj`` nor ``w2``.
        """
        # Optimization for GPT-OSS model
        # Lazily create the MXFP4 flag. It only becomes True in the
        # precision-config branch below.
        if getattr(self, "use_mxfp4", None) is None:
            self.use_mxfp4 = False
        w1_scale = None
        w2_scale = None

        # One-time repack of the MXFP4 weight scales. The transposed scales
        # are packed and written back into the precision configs' storage.
        if (
            not getattr(self, "packed_scales", False)
            and hasattr(self.experts, "gate_up_proj")
            and getattr(self.experts, "gate_up_proj_precision_config", None) is not None
        ):
            # convert scales
            data_1 = ops.convert_scale_packed(
                self.experts.gate_up_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            data_2 = ops.convert_scale_packed(
                self.experts.down_proj_precision_config.weight_scale.data.transpose(-1, -2).contiguous()
            )
            self.experts.gate_up_proj_precision_config.weight_scale.storage.data = data_1
            self.experts.down_proj_precision_config.weight_scale.storage.data = data_2
            self.packed_scales = True
            self.use_mxfp4 = True

        # One-time repack of the expert weights. This only happens for experts
        # exposing gate_up_proj/down_proj. Once done, packed_weight=True and
        # the kernel is called with is_vnni=True.
        if not getattr(self, "packed_weight", False) and hasattr(
            self.experts, "gate_up_proj"
        ):
            # convert weights
            data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
            data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
            if self.use_mxfp4:
                # MXFP4 weights: write the packed data into the tensors' storage.
                self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
                self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
            else:
                # convert_weight_packed only supports bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).
                data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
                data_2 = data_2.to(torch.bfloat16) if data_2.dtype == torch.float32 else data_2
                self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
                self.experts.down_proj.data = ops.convert_weight_packed(data_2)
            # C++ kernel does not support float32.
            # The biases are cast once, here, based on the dtype of the first
            # input seen.
            dtype = torch.bfloat16 if x.dtype == torch.float32 else x.dtype
            if getattr(self.experts, "gate_up_proj_bias", None) is not None:
                self.experts.gate_up_proj_bias.data = self.experts.gate_up_proj_bias.data.to(dtype)
            if getattr(self.experts, "down_proj_bias", None) is not None:
                self.experts.down_proj_bias.data = self.experts.down_proj_bias.data.to(dtype)
            self.packed_weight = True

        # Get MoE parameters. Fall back to 4 / 128 when the host module does
        # not expose top_k / num_experts.
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)

        # Detect activation type.
        # Experts carrying both `alpha` and `limit` use swigluoai with those
        # values. Otherwise `experts.activation` (default "silu") is used, and
        # 1.702 / 7.0 only take effect if that attribute is "swigluoai".
        if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
            activation = "swigluoai"
            alpha = self.experts.alpha
            limit = self.experts.limit
        else:
            activation = getattr(self.experts, "activation", "silu")
            alpha = 1.702
            limit = 7.0

        # Get weight tensors. The w1/w2 path is never repacked above, so it
        # reaches the kernel with is_vnni=False.
        if hasattr(self.experts, "gate_up_proj"):
            w1 = self.experts.gate_up_proj
        elif hasattr(self.experts, "w1"):
            w1 = self.experts.w1
            w3 = getattr(self.experts, "w3", None)
            if w3 is not None:
                # NOTE(review): fused_moe_cpp documents w1 as [E, 2N, K], i.e.
                # gate/up stacked along dim -2. Concatenating along dim=-1 only
                # fits that if w1/w3 use a different layout; confirm. This
                # concatenation also allocates on every forward call.
                w1 = torch.cat([w1, w3], dim=-1)
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")
        if hasattr(self.experts, "down_proj"):
            w2 = self.experts.down_proj
        elif hasattr(self.experts, "w2"):
            w2 = self.experts.w2
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # Get optional bias tensors (raw .data)
        w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
        w2_bias = getattr(self.experts, "down_proj_bias", None)
        w1_bias = w1_bias if w1_bias is None else w1_bias.data
        w2_bias = w2_bias if w2_bias is None else w2_bias.data

        if self.use_mxfp4:
            w1_scale = self.experts.gate_up_proj_precision_config.weight_scale.data
            w2_scale = self.experts.down_proj_precision_config.weight_scale.data

        # Store original shape
        in_shape = x.size()

        # Route tokens to experts (Python implementation is fast enough).
        # The router logits are not needed here.
        _, expert_weights, expert_indices = route_tokens_cpu(
            x,
            self.router.weight,
            getattr(self.router, "bias", None),
            moe_top_k,
            moe_num_experts,
            moe_normalize_expert_weights,
        )

        # Flatten to [num_tokens, hidden_size]. Unlike view(), reshape()
        # also accepts non-contiguous inputs (copying only when required).
        x_flat = x.reshape(-1, x.shape[-1])

        # Pass alpha/limit to the kernel only for the swigluoai activation.
        use_alpha = alpha if activation == "swigluoai" else None
        use_limit = limit if activation == "swigluoai" else None

        # Call C++ optimized kernel
        output = fused_moe_cpp(
            hidden_states=x_flat,
            w1=w1.data,
            w2=w2.data,
            topk_weights=expert_weights,
            topk_ids=expert_indices.to(torch.int32),
            inplace=False,
            use_int8_w8a8=False,
            use_fp8_w8a16=False,
            use_mxfp4=self.use_mxfp4,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            block_size=None,
            a1_scale=None,
            a2_scale=None,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            alpha=use_alpha,
            limit=use_limit,
            is_vnni=getattr(self, "packed_weight", False),
        )

        # Restore original shape
        output = output.reshape(in_shape)
        return output, expert_weights
# Names exported by `from <module> import *`; every entry must be defined in
# this module, otherwise the star-import raises AttributeError.
__all__ = ["fused_moe_cpp", "CPUMegaBlocksMoeMLP"]
|