Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
  ```python
  # !pip install kernels
  from kernels import get_kernel

  kernel = get_kernel("kernels-community/megablocks")
  ```
- Notebooks
- Google Colab
- Kaggle
| # SPDX-License-Identifier: Apache-2.0 | |
| # MegaBlocks CPU Fused MoE Implementation | |
| # | |
| # This is a pure Python/PyTorch implementation for CPU. | |
| # For better performance, consider using the C++ kernel implementation. | |
| # | |
| import torch | |
| import torch.nn.functional as F | |
def swigluoai_activation(gate: torch.Tensor, up: torch.Tensor,
                         alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    """
    Clamped SwiGLU variant ("SwigluOAI") used by GptOss models.

    Steps:
        gate' = clamp(gate, max=limit)
        up'   = clamp(up, -limit, limit)
        out   = (up' + 1) * gate' * sigmoid(alpha * gate')

    Args:
        gate: Output of the gate projection
        up: Output of the up projection
        alpha: Sigmoid scaling factor (default: 1.702)
        limit: Clamp bound (default: 7.0); only an upper bound for ``gate``,
            a symmetric bound for ``up``

    Returns:
        The activated tensor
    """
    clipped_gate = torch.clamp(gate, max=limit)
    clipped_up = torch.clamp(up, min=-limit, max=limit)
    gated = clipped_gate * torch.sigmoid(alpha * clipped_gate)
    return gated * (clipped_up + 1)
def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    Gated SiLU: apply SiLU (Swish) to ``gate``, then multiply element-wise by ``up``.

    Computes ``silu(gate) * up``.

    Args:
        gate: Output of the gate projection
        up: Output of the up projection

    Returns:
        The element-wise product tensor
    """
    activated_gate = F.silu(gate)
    return torch.mul(activated_gate, up)
| def route_tokens_cpu( | |
| x: torch.Tensor, | |
| router_weight: torch.Tensor, | |
| router_bias: torch.Tensor | None, | |
| moe_top_k: int, | |
| moe_num_experts: int, | |
| moe_normalize_expert_weights: int | None = None, | |
| ) -> tuple: | |
| """ | |
| Route tokens to experts and compute expert weights and indices (CPU version). | |
| Args: | |
| x: Input tensor [batch, seq, hidden] or [tokens, hidden] | |
| router_weight: Router weight [num_experts, hidden] | |
| router_bias: Router bias [num_experts] or None | |
| moe_top_k: Number of experts per token | |
| moe_num_experts: Total number of experts | |
| moe_normalize_expert_weights: Normalization order or None | |
| Returns: | |
| Tuple of (logits, expert_weights, expert_indices) | |
| """ | |
| x_flat = x.view(-1, x.shape[-1]) | |
| logits = F.linear(x_flat, router_weight, router_bias) | |
| if moe_top_k == 1: | |
| expert_weights, expert_indices = logits.max(dim=-1, keepdim=True) | |
| else: | |
| expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1) | |
| expert_weights = expert_weights.softmax(dim=-1) | |
| if moe_normalize_expert_weights is not None: | |
| expert_weights = expert_weights / torch.norm( | |
| expert_weights, | |
| p=moe_normalize_expert_weights, | |
| dim=-1, | |
| keepdim=True, | |
| ) | |
| return logits, expert_weights, expert_indices | |
def cpu_fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    activation: str = "silu",
    alpha: float = 1.702,
    limit: float = 7.0,
    is_interleaved: bool = True,
) -> torch.Tensor:
    """
    CPU fused MoE forward pass built from plain PyTorch ops.

    Experts are processed sequentially, but only experts that receive at least
    one token are visited. Each visited expert handles all of its tokens with
    two batched matmuls. The weighted expert outputs are accumulated into the
    result with ``index_add_``.

    Args:
        hidden_states: [num_tokens, hidden_size]
        w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
        w2: [num_experts, inter_size, hidden_size] - down_proj weights
        topk_weights: [num_tokens, topk] - routing weights
        topk_ids: [num_tokens, topk] - expert indices; ids outside
            [0, num_experts) are ignored
        w1_bias: [num_experts, 2*inter_size] or None
        w2_bias: [num_experts, hidden_size] or None
        activation: "swigluoai" selects swigluoai_activation; any other value
            falls back to silu_and_mul_activation
        alpha: swigluoai alpha parameter
        limit: swigluoai limit parameter
        is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...]
            (True for GptOss) or laid out as [gate_all, up_all]

    Returns:
        output: [num_tokens, hidden_size], same dtype as ``hidden_states``
    """
    num_experts = w1.shape[0]
    inter_size = w2.shape[1]

    output = torch.zeros_like(hidden_states)

    # Visit only experts that actually received tokens. torch.unique sorts its
    # output, so the accumulation order matches an ascending sweep over all
    # experts. Out-of-range ids are dropped, as such a sweep never matches them.
    active_experts = [
        e for e in torch.unique(topk_ids).tolist() if 0 <= e < num_experts
    ]

    for expert_idx in active_experts:
        # (token, top-k slot) pairs routed to this expert; non-empty by construction.
        token_indices, topk_positions = torch.where(topk_ids == expert_idx)

        # First projection: [num_selected, hidden] @ [hidden, 2*inter].
        gate_up = hidden_states[token_indices] @ w1[expert_idx]
        if w1_bias is not None:
            gate_up = gate_up + w1_bias[expert_idx]

        if is_interleaved:
            # GptOss layout: [g0, u0, g1, u1, ...]
            gate = gate_up[..., ::2]
            up = gate_up[..., 1::2]
        else:
            # Standard layout: [gate_all, up_all]
            gate = gate_up[..., :inter_size]
            up = gate_up[..., inter_size:]

        if activation == "swigluoai":
            activated = swigluoai_activation(gate, up, alpha, limit)
        else:
            activated = silu_and_mul_activation(gate, up)

        # Second projection: [num_selected, inter] @ [inter, hidden].
        expert_out = activated @ w2[expert_idx]
        if w2_bias is not None:
            expert_out = expert_out + w2_bias[expert_idx]

        # Per-pair routing weight, broadcast over the hidden dimension.
        weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
        output.index_add_(0, token_indices, (expert_out * weights).to(output.dtype))

    return output
class MegaBlocksMoeMLP(torch.nn.Module):
    """
    CPU MoE MLP module that can be used as a drop-in replacement for
    the transformers GptOssMLP when using @use_kernel_forward_from_hub.

    The class defines no constructor or attributes of its own. ``forward``
    expects the host module to provide ``router`` (with ``weight`` and
    optionally ``bias``/``top_k``) and ``experts`` (holding the expert weights).
    """

    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Forward pass through the MoE layer.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size] or
               [tokens, hidden_size]; may be non-contiguous.

        Returns:
            Tuple of (output, expert_weights) where:
              - output: Tensor of same shape as input
              - expert_weights: Expert weights for each token [tokens, top_k]

        Raises:
            AttributeError: if ``self.experts`` has neither ``gate_up_proj``
                nor ``w1``, or neither ``down_proj`` nor ``w2``.
        """
        # MoE hyper-parameters, with fallbacks when the wrapped modules lack them.
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        moe_normalize_expert_weights = getattr(
            self.experts, "normalize_expert_weights", None
        )

        # Experts exposing both `alpha` and `limit` use the clamped SwigluOAI
        # activation; otherwise use the declared activation (default "silu").
        if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
            activation = "swigluoai"
            alpha = self.experts.alpha
            limit = self.experts.limit
        else:
            activation = getattr(self.experts, "activation", "silu")
            alpha = 1.702
            limit = 7.0

        # First projection weights. `gate_up_proj` is interleaved (GptOss);
        # separate w1/w3 are concatenated as [w1 | w3], i.e. [gate | up].
        if hasattr(self.experts, "gate_up_proj"):
            w1 = self.experts.gate_up_proj
            is_interleaved = True
        elif hasattr(self.experts, "w1"):
            w1 = self.experts.w1
            w3 = getattr(self.experts, "w3", None)
            if w3 is not None:
                w1 = torch.cat([w1, w3], dim=-1)
            is_interleaved = False
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")

        # Second projection weights.
        if hasattr(self.experts, "down_proj"):
            w2 = self.experts.down_proj
        elif hasattr(self.experts, "w2"):
            w2 = self.experts.w2
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # Optional biases.
        w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
        w2_bias = getattr(self.experts, "down_proj_bias", None)

        in_shape = x.shape
        # Flatten once with reshape (view() raises on non-contiguous input) and
        # feed the same 2-D tensor to both the router and the experts.
        x_flat = x.reshape(-1, x.shape[-1])

        _, expert_weights, expert_indices = route_tokens_cpu(
            x_flat,
            self.router.weight,
            getattr(self.router, "bias", None),
            moe_top_k,
            moe_num_experts,
            moe_normalize_expert_weights,
        )

        output = cpu_fused_moe(
            hidden_states=x_flat,
            w1=w1,
            w2=w2,
            topk_weights=expert_weights,
            topk_ids=expert_indices,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            activation=activation,
            alpha=alpha,
            limit=limit,
            is_interleaved=is_interleaved,
        )

        # Restore the caller's shape (reshape tolerates any output strides).
        return output.reshape(in_shape), expert_weights
# Public API of this module: the MoE layer, the fused expert computation,
# the router, and the two supported activations.
__all__ = [
    "MegaBlocksMoeMLP", "cpu_fused_moe", "route_tokens_cpu",
    "swigluoai_activation", "silu_and_mul_activation",
]