Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/megablocks")
- Notebooks
- Google Colab
- Kaggle
| # SPDX-License-Identifier: Apache-2.0 | |
| # MegaBlocks XPU Fused MoE Implementation | |
| import os | |
| import torch | |
| from ._ops import ops | |
| # Install meta kernels for torch.compile compatibility | |
def _install_xpu_meta_kernels():
    """Replace the XPU MoE ops on ``ops`` with torch.compile-aware wrappers.

    Each op that exists on ``ops`` is rebound to a Python wrapper. Outside of
    compilation the wrapper forwards to the original op unchanged. While
    ``torch.compiler.is_compiling()`` is true the wrapper does not call the op:
    the grouped GEMM wrapper returns its ``ptr_D`` argument and every other
    wrapper returns ``None``.

    NOTE(review): because the compiling branch returns without invoking the
    kernel, confirm that the compiled graph still executes the real op and
    does not silently drop it.
    """

    def wrap_grouped_gemm(kernel):
        def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
                           expert_first_token_offset, N, K, num_experts,
                           is_B_int4, is_B_mxfp4):
            if not torch.compiler.is_compiling():
                return kernel(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
                              expert_first_token_offset, N, K, num_experts,
                              is_B_int4, is_B_mxfp4)
            # Compiling: ptr_D is the output buffer, hand it back untouched.
            return ptr_D

        return gemm_with_meta

    def wrap_prologue(kernel):
        def prologue_with_meta(input, token_selected_experts,
                               token_final_scales, workspace, hidden_size,
                               inter_size, num_experts_on_rank):
            if not torch.compiler.is_compiling():
                return kernel(input, token_selected_experts,
                              token_final_scales, workspace, hidden_size,
                              inter_size, num_experts_on_rank)
            # Compiling: the op only writes into ``workspace``.
            return None

        return prologue_with_meta

    def wrap_gather(kernel):
        def gather_with_meta(output, moe_output, topk_weights,
                             unpermuted_row_to_permuted_row, num_experts):
            if not torch.compiler.is_compiling():
                return kernel(output, moe_output, topk_weights,
                              unpermuted_row_to_permuted_row, num_experts)
            # Compiling: the op only writes into ``output``.
            return None

        return gather_with_meta

    def wrap_activation(kernel):
        def act_with_meta(*args, **kwargs):
            if not torch.compiler.is_compiling():
                return kernel(*args, **kwargs)
            # Compiling: activation ops are treated as in-place, no result.
            return None

        return act_with_meta

    activation_names = (
        "silu_and_mul",
        "gelu_and_mul",
        "gelu_tanh_and_mul",
        "gelu_fast",
        "gelu_new",
        "gelu_quick",
        "mul_and_silu",
        "swigluoai_and_mul",
    )

    # (op name, wrapper factory), applied in the listed order.
    patch_table = [
        ("cutlass_grouped_gemm_interface", wrap_grouped_gemm),
        ("fused_moe_prologue", wrap_prologue),
        ("moe_gather", wrap_gather),
    ]
    patch_table.extend((name, wrap_activation) for name in activation_names)

    for op_name, make_wrapper in patch_table:
        # Ops missing from this build are skipped silently.
        if hasattr(ops, op_name):
            setattr(ops, op_name, make_wrapper(getattr(ops, op_name)))
| # Install meta kernels on module load | |
| _install_xpu_meta_kernels() | |
| # default | |
def cutlass_grouped_gemm(input_A, input_B, bias, output, expert_token_count, n,
                         k, num_experts):
    """Launch the grouped GEMM kernel without scales or 4-bit weights.

    Args:
        input_A: forwarded to the kernel as ``ptr_A``; its device is also used
            for the offsets tensor built here.
        input_B: forwarded to the kernel as ``ptr_B``.
        bias: forwarded unchanged as ``ptr_bias`` (may be ``None``).
        output: forwarded as ``ptr_D``; presumably written by the kernel —
            confirm against the kernel implementation.
        expert_token_count: iterable of per-expert row counts (presumably
            Python ints — confirm against callers).
        n: forwarded as ``N``.
        k: forwarded as ``K``.
        num_experts: forwarded as ``num_experts``.

    Returns:
        None.
    """
    from itertools import accumulate

    # Exclusive prefix sum with the grand total appended:
    # [0, c0, c0 + c1, ..., sum(counts)].
    offsets = list(accumulate(expert_token_count, initial=0))

    # Allocate next to the activations instead of on the default "xpu" device
    # so the kernel never receives tensors living on different devices.
    expert_offset = torch.tensor(offsets,
                                 dtype=torch.int64,
                                 device=input_A.device)

    ops.cutlass_grouped_gemm_interface(
        ptr_A=input_A,
        ptr_B=input_B,
        ptr_scales=None,
        ptr_bias=bias,
        ptr_D=output,
        expert_first_token_offset=expert_offset,
        N=n,
        K=k,
        num_experts=num_experts,
        is_B_int4=False,
        is_B_mxfp4=False)
def cutlass_grouped_gemm_xe2(input_A, input_B, scales, bias, output,
                             num_rows_per_expert, n, k, num_experts, is_B_int4,
                             is_B_mxfp4):
    """Launch the grouped GEMM kernel from a tensor of per-expert row counts.

    The counts in ``num_rows_per_expert`` are turned into int64 offsets
    ``[0, cumsum(counts)...]`` and passed as ``expert_first_token_offset``.
    Every other argument is forwarded to the kernel unchanged.

    Returns:
        None.
    """
    # A single zero with the same dtype/device as the counts, so the
    # concatenation below needs no device transfer.
    leading_zero = torch.tensor([0],
                                dtype=num_rows_per_expert.dtype,
                                device=num_rows_per_expert.device)
    running_totals = torch.cumsum(num_rows_per_expert, dim=0)
    offsets = torch.cat([leading_zero, running_totals])
    offsets = offsets.to(torch.int64)

    kernel_args = dict(
        ptr_A=input_A,
        ptr_B=input_B,
        ptr_scales=scales,
        ptr_bias=bias,
        ptr_D=output,
        expert_first_token_offset=offsets,
        N=n,
        K=k,
        num_experts=num_experts,
        is_B_int4=is_B_int4,
        is_B_mxfp4=is_B_mxfp4,
    )
    ops.cutlass_grouped_gemm_interface(**kernel_args)
def ceilDiv(a, b):
    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``.

    This is the usual ceiling division for non-negative ``a`` and positive
    ``b``; other sign combinations simply follow the formula.
    """
    biased_numerator = a + b - 1
    return biased_numerator // b
def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
    """Pick the smallest candidate block size that satisfies the fit check.

    A candidate ``size`` is accepted when
    ``ceilDiv(num_tokens, size) * num_experts_per_node <= size``.
    Candidates are tried in ascending order; 1024 is returned when none fits.
    """
    candidate_sizes = (32, 64, 128, 256, 512, 1024)
    fitting_sizes = (
        size for size in candidate_sizes
        if ceilDiv(num_tokens, size) * num_experts_per_node <= size
    )
    return next(fitting_sizes, 1024)
def implement_zp(qweight):
    """Convert packed unsigned 4-bit weights to packed signed 4-bit weights.

    Each byte of ``qweight`` holds two u4 values (high and low nibble). Every
    nibble ``v`` is replaced by the 4-bit two's-complement encoding of
    ``v - 8``, which bakes the default zero point of 8 into the weights so the
    GEMM kernel does not have to apply one. Only that default zero point is
    supported.

    Args:
        qweight: ``torch.uint8`` tensor of packed nibbles, any shape.

    Returns:
        A new ``torch.uint8`` tensor with the same shape as ``qweight``.

    Raises:
        AssertionError: if ``qweight`` is not ``torch.uint8``.
    """
    assert qweight.dtype == torch.uint8, "Input tensor must be uint8"
    # Subtracting 8 modulo 16 only toggles bit 3 of a nibble, so the whole
    # unpack / subtract / repack sequence collapses to one XOR that flips the
    # top bit of both nibbles at once.
    return qweight ^ 0x88
def xpu_fused_moe(hidden_states,
                  w13,
                  w13_scales,
                  w13_bias,
                  w2,
                  w2_scales,
                  w2_bias,
                  topk_weights,
                  topk_ids,
                  n_experts_per_token,
                  activation,
                  num_experts,
                  is_fp8=False,
                  is_int4=False,
                  is_mxfp4=False):
    """Run the fused MoE pipeline: prologue, GEMM1, activation, GEMM2, gather.

    Shapes as documented by the original author:
        hidden_states: [num_rows, hidden_size]
        w13: [num_experts, 2*inter_size, hidden_size]
            (a [num_experts, hidden_size, 2*inter_size] layout is also
            accepted; it is detected by comparing w13.shape[-2] with
            hidden_size)
        w13_scales:
            None for bf16/fp16
            or [num_experts] for fp8
            or [num_experts, 2*inter_size, hidden_size // group_size] for 4bits
        w13_bias: [num_experts, 2*inter_size] or None
        w2: [num_experts, hidden_size, inter_size]
        w2_scales:
            None for bf16/fp16
            or [num_experts] for fp8
            or [num_experts, hidden_size, inter_size // group_size] for 4bits
        w2_bias: [num_experts, hidden_size] or None
        topk_weights: [num_rows, topk]
        topk_ids: [num_rows, topk]
        n_experts_per_token: int
        activation: str, one of "silu", "gelu", "swigluoai"
        num_experts: int
        is_fp8: bool
        is_int4: bool
        is_mxfp4: bool

    Side effects:
        On first use the weights are prepared in place by rebinding ``.data``:
        non-4-bit weights are transposed when the layout check asks for it,
        int4 weights go through ``implement_zp``. ``w13`` and ``w2`` are each
        tagged with an ``xpu_fused_moe`` attribute (``w13`` additionally with
        ``inter_size``) so that an already-prepared tensor is never processed
        twice. ``w2`` carries its own tag because callers may pass a freshly
        built ``w13`` on every call while reusing the same ``w2``.

    Returns:
        Tensor shaped like ``hidden_states``.

    Raises:
        AssertionError: if ``w13`` or ``w2`` is not contiguous.
        ValueError: if ``activation`` is not supported.
    """
    output = torch.empty_like(hidden_states)
    num_rows, hidden_size = list(hidden_states.shape)

    # w13 stacks gate and up projections, so one of its last two dims is
    # 2*inter_size and the other is hidden_size.
    dim_last = w13.shape[-1]
    dim_second_last = w13.shape[-2]
    if dim_second_last == hidden_size:
        # [E, hidden_size, 2*inter_size]: already in the [E, K, N] layout.
        inter_size = dim_last // 2
        needs_transpose = False
    else:
        # [E, 2*inter_size, hidden_size]: must be transposed for non-4-bit.
        inter_size = dim_second_last // 2
        needs_transpose = True

    assert w13.is_contiguous() and w2.is_contiguous()

    w13_prepared = hasattr(w13, 'xpu_fused_moe')
    w2_prepared = hasattr(w2, 'xpu_fused_moe')

    def _int4_to_signed(weight):
        # Per-expert u4 -> s4 conversion into a fresh contiguous tensor.
        converted = torch.empty_like(weight)
        for expert in range(num_experts):
            converted[expert] = implement_zp(weight[expert])
        return converted.contiguous()

    # 4-bit kernels consume [E, N, K]; every other dtype consumes [E, K, N].
    if not is_int4 and not is_mxfp4:
        if w13_prepared:
            inter_size = w13.inter_size
        else:
            if needs_transpose:
                w13.data = w13.transpose(-1, -2).contiguous()
                # Only touch w2 if no earlier call already prepared it;
                # transposing it again would flip it back.
                if not w2_prepared:
                    w2.data = w2.transpose(-1, -2).contiguous()
            w13.xpu_fused_moe = True
            w13.inter_size = inter_size
            w2.xpu_fused_moe = True

    if is_int4 and not w13_prepared:
        w13.data = _int4_to_signed(w13)
        # Converting w2 twice would undo the zero-point shift.
        if not w2_prepared:
            w2.data = _int4_to_signed(w2)
        w13.xpu_fused_moe = True
        w2.xpu_fused_moe = True

    # TODO: will all integrated in Cpp func. Temporary expose before gemm fusion
    num_experts_per_node = num_experts
    num_moe_inputs = n_experts_per_token * num_rows
    num_tokens_per_block = compute_num_tokens_per_block(
        num_rows, num_experts_per_node)
    num_blocks_per_seq = ceilDiv(num_rows, num_tokens_per_block)

    # Region sizes in bytes. The two regions read back below are viewed as
    # int32 (4 bytes) and int64 (8 bytes); the remaining 4-byte factors are
    # presumably int32/float32 elements used by the prologue — confirm.
    per_input_bytes = num_moe_inputs * 4
    src_to_dest_map_size = n_experts_per_token * num_rows * 4
    expert_first_token_offset_size = (num_experts_per_node + 1) * 8
    blocked_expert_counts_size = (num_experts_per_node * num_blocks_per_seq *
                                  4)
    blocked_row_to_unpermuted_row_size = num_experts_per_node * num_rows * 4
    permuted_data_size = (num_moe_inputs * hidden_size *
                          hidden_states.element_size())

    # The region order defines the workspace layout handed to the prologue.
    workspace_regions = (
        ("permuted_row_to_unpermuted_row", per_input_bytes),
        ("permuted_token_selected_experts", per_input_bytes),
        ("unpermuted_row_to_permuted_row", src_to_dest_map_size),
        ("blocked_expert_counts", blocked_expert_counts_size),
        ("blocked_expert_counts_cumsum", blocked_expert_counts_size),
        ("blocked_row_to_unpermuted_row", blocked_row_to_unpermuted_row_size),
        ("expert_first_token_offset", expert_first_token_offset_size),
        ("permuted_token_final_scales", per_input_bytes),
        ("overlapped_gemm1_gemm2_inputs", permuted_data_size),
    )
    region_offset = {}
    workspace_bytes = 0
    for region_name, region_bytes in workspace_regions:
        region_offset[region_name] = workspace_bytes
        # Each region is padded up to a multiple of 256 bytes.
        workspace_bytes += (region_bytes + 255) // 256 * 256

    workspace = torch.zeros(workspace_bytes,
                            dtype=torch.uint8,
                            device=hidden_states.device)

    if topk_ids.dtype == torch.int32:
        topk_ids = topk_ids.to(torch.int64)

    ops.fused_moe_prologue(
        input=hidden_states,
        token_selected_experts=topk_ids,
        token_final_scales=topk_weights,
        workspace=workspace,
        hidden_size=hidden_size,
        inter_size=inter_size,
        num_experts_on_rank=num_experts_per_node)

    def _region(name, nbytes):
        # Unpadded byte slice of one workspace region.
        start = region_offset[name]
        return workspace[start:start + nbytes]

    expert_first_token_offset = _region(
        "expert_first_token_offset",
        expert_first_token_offset_size).view(torch.int64)
    unpermuted_row_to_permuted_row = _region(
        "unpermuted_row_to_permuted_row",
        src_to_dest_map_size).view(torch.int32)
    gemm1_input = _region("overlapped_gemm1_gemm2_inputs",
                          permuted_data_size).view(hidden_states.dtype).view(
                              num_moe_inputs, hidden_size)

    # Scales are only passed to the kernel for fp8 / int4 / mxfp4 weights.
    unquantized = not is_fp8 and not is_int4 and not is_mxfp4

    ########### gemm1 ##################
    gemm1_output = torch.empty((num_moe_inputs, 2 * inter_size),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device)
    ops.cutlass_grouped_gemm_interface(
        ptr_A=gemm1_input,
        ptr_B=w13,
        ptr_scales=None if unquantized else w13_scales,
        ptr_bias=w13_bias,
        ptr_D=gemm1_output,
        expert_first_token_offset=expert_first_token_offset,
        N=2 * inter_size,
        K=hidden_size,
        num_experts=num_experts_per_node,
        is_B_int4=is_int4,
        is_B_mxfp4=is_mxfp4)

    # act
    act_output = torch.empty((num_moe_inputs, inter_size),
                             dtype=gemm1_output.dtype,
                             device=gemm1_output.device)
    if activation == "silu":
        ops.silu_and_mul(act_output, gemm1_output)
    elif activation == "gelu":
        ops.gelu_and_mul(act_output, gemm1_output)
    elif activation == "swigluoai":
        # The two scalars are passed to the kernel as-is; presumably the
        # swigluoai alpha and clamp limit — confirm against the kernel.
        ops.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

    ########### gemm2 ##################
    gemm2_input = act_output.contiguous()
    gemm2_output = torch.empty((num_moe_inputs, hidden_size),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device)
    ops.cutlass_grouped_gemm_interface(
        ptr_A=gemm2_input,
        ptr_B=w2,
        ptr_scales=None if unquantized else w2_scales,
        ptr_bias=w2_bias,
        ptr_D=gemm2_output,
        expert_first_token_offset=expert_first_token_offset,
        N=hidden_size,
        K=inter_size,
        num_experts=num_experts_per_node,
        is_B_int4=is_int4,
        is_B_mxfp4=is_mxfp4)

    ops.moe_gather(output, gemm2_output, topk_weights,
                   unpermuted_row_to_permuted_row,
                   num_experts_per_node)
    return output
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
    """Multiply ``x`` element-wise by uniform noise in [1 - eps, 1 + eps).

    The noise tensor has the shape, dtype and device of ``x``.
    """
    lower = 1.0 - moe_jitter_eps
    upper = 1.0 + moe_jitter_eps
    width = upper - lower
    uniform = torch.rand(x.size(), dtype=x.dtype, device=x.device)
    # Map the [0, 1) samples onto [lower, upper).
    factors = lower + uniform * width
    return x * factors
def compute_top_k(scores: torch.Tensor, moe_top_k: int):
    """Return the ``(values, indices)`` of the top-k entries along the last dim.

    For ``moe_top_k == 1`` this uses ``max`` with ``keepdim=True``; otherwise
    it uses ``torch.topk``.
    """
    if moe_top_k != 1:
        return torch.topk(scores, moe_top_k, dim=-1)
    return scores.max(dim=-1, keepdim=True)
def route_tokens_xpu(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: float = None,
    moe_normalize_expert_weights: int = None,
    training: bool = False,
) -> tuple:
    """Compute router logits and per-token expert weights/indices (XPU version).

    Steps: optional jitter (only when ``training`` and ``moe_jitter_eps`` is
    set), flatten to 2-D, linear router, top-k selection on the logits,
    softmax over the selected logits, then optional p-norm normalisation.
    ``moe_num_experts`` is accepted but not used here.

    Returns:
        ``(logits, expert_weights, expert_indices)``.
    """
    jitter_enabled = training and moe_jitter_eps is not None
    if jitter_enabled:
        x = apply_jitter(x, moe_jitter_eps)

    # Collapse every leading dimension into a single token dimension.
    tokens = x.view(-1, x.shape[-1])
    logits = torch.nn.functional.linear(tokens, router_weight, router_bias)

    top_logits, expert_indices = compute_top_k(logits, moe_top_k)
    # The softmax runs over the k selected logits, not over all experts.
    expert_weights = top_logits.softmax(dim=-1)

    if moe_normalize_expert_weights is None:
        return logits, expert_weights, expert_indices

    weight_norm = torch.norm(
        expert_weights,
        p=moe_normalize_expert_weights,
        dim=-1,
        keepdim=True,
    )
    return logits, expert_weights / weight_norm, expert_indices
class MegaBlocksMoeMLP(torch.nn.Module):
    # Flag read by the surrounding framework; set to True in this file.
    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Route tokens and run them through the XPU fused MoE kernel.

        The module reads its configuration and weights from ``self.router``
        and ``self.experts``; missing optional attributes fall back to the
        defaults visible below.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size] or [tokens, hidden_size]

        Returns:
            Tuple of (output, expert_weights) where:
            - output: Tensor of same shape as input
            - expert_weights: Expert weights for each token [tokens, top_k]

        Raises:
            AttributeError: if ``self.experts`` exposes neither
                ``gate_up_proj``/``w1`` nor ``down_proj``/``w2``.
        """
        router = self.router
        experts = self.experts

        # Routing configuration, with fallbacks when an attribute is absent.
        top_k = getattr(router, "top_k", 4)
        num_experts = getattr(experts, "num_experts", 128)
        jitter_eps = getattr(experts, "jitter_eps", None)
        normalize_p = getattr(experts, "normalize_expert_weights", None)

        # Experts carrying both ``alpha`` and ``limit`` (GptOss style) are
        # treated as using the swigluoai activation.
        is_gpt_oss_style = hasattr(experts, "alpha") and hasattr(
            experts, "limit")
        if is_gpt_oss_style:
            activation = "swigluoai"
        else:
            activation = getattr(experts, "activation", "silu")

        # Gate/up weights: either one fused tensor or separate w1 (+ w3).
        if hasattr(experts, "gate_up_proj"):
            # NOTE: swigluoai_and_mul kernel expects interleaved layout
            # [g0,u0,g1,u1,...], which matches GptOss's gate_up_proj layout,
            # so no conversion is needed.
            w13 = experts.gate_up_proj
        elif hasattr(experts, "w1"):
            w1 = experts.w1
            w3 = getattr(experts, "w3", None)
            # Stack w1 and w3 along the second-to-last dim when both exist.
            w13 = w1 if w3 is None else torch.cat([w1, w3], dim=-2)
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")

        # Down-projection weights under either naming convention.
        for w2_name in ("down_proj", "w2"):
            if hasattr(experts, w2_name):
                w2 = getattr(experts, w2_name)
                break
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # Optional biases, quantisation flags and scales.
        w13_bias = getattr(experts, "gate_up_proj_bias", None)
        w2_bias = getattr(experts, "down_proj_bias", None)
        is_fp8 = getattr(experts, "is_fp8", False)
        is_int4 = getattr(experts, "is_int4", False)
        is_mxfp4 = getattr(experts, "is_mxfp4", False)
        w13_scales = getattr(experts, "gate_up_proj_scales", None)
        w2_scales = getattr(experts, "down_proj_scales", None)

        # Remember the caller's shape so the output can be restored to it.
        original_shape = x.size()

        _logits, expert_weights, expert_indices = route_tokens_xpu(
            x,
            router.weight,
            getattr(router, "bias", None),
            top_k,
            num_experts,
            jitter_eps,
            normalize_p,
            self.training,
        )

        # The fused kernel works on a 2-D [tokens, hidden_size] view.
        tokens = x.view(-1, x.shape[-1])
        fused_output = xpu_fused_moe(
            hidden_states=tokens,
            w13=w13,
            w13_scales=w13_scales,
            w13_bias=w13_bias,
            w2=w2,
            w2_scales=w2_scales,
            w2_bias=w2_bias,
            topk_weights=expert_weights.float(),
            topk_ids=expert_indices,
            n_experts_per_token=top_k,
            activation=activation,
            num_experts=num_experts,
            is_fp8=is_fp8,
            is_int4=is_int4,
            is_mxfp4=is_mxfp4,
        )

        return fused_output.view(original_shape), expert_weights
| # Export classes and functions | |
| __all__ = [ | |
| "MegaBlocksMoeMLP", | |
| "xpu_fused_moe", | |
| "cutlass_grouped_gemm", | |
| "cutlass_grouped_gemm_xe2", | |
| ] |