Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/megablocks")
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn.functional as F | |
| from collections import namedtuple | |
| from kernels.benchmark import Benchmark | |
def moe_mlp_reference(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    gate_up_proj: torch.Tensor,
    gate_up_proj_bias: torch.Tensor,
    down_proj: torch.Tensor,
    down_proj_bias: torch.Tensor,
    top_k: int = 4,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute a top-k routed mixture-of-experts MLP in plain PyTorch.

    Every token (a row of ``x`` flattened over all leading dimensions) is
    scored by a linear router, the ``top_k`` highest-scoring experts are
    selected, and their outputs are combined using a softmax taken over the
    selected logits only.

    Args:
        x: Input whose last dimension is the hidden size. It may have any
            number of leading dimensions and does not need to be contiguous.
        router_weight: Router matrix of shape ``(num_experts, hidden_size)``.
        router_bias: Router bias passed to ``F.linear``.
        gate_up_proj: Per-expert up projection, indexed by expert on dim 0 and
            right-multiplied onto the tokens. Gate and up channels are
            interleaved on the last dimension (even = gate, odd = up).
        gate_up_proj_bias: Per-expert bias added after the up projection.
        down_proj: Per-expert down projection, indexed by expert on dim 0.
        down_proj_bias: Per-expert bias added after the down projection.
        top_k: Number of experts selected for each token.
        alpha: Scale applied to the gate inside the sigmoid.
        limit: Clamp bound; the gate is clamped from above only, the up
            branch is clamped to ``[-limit, limit]``.

    Returns:
        A tuple ``(output, routing_weights)`` where ``output`` has the same
        shape as ``x`` and ``routing_weights`` has shape
        ``(num_tokens, top_k)``.
    """
    in_shape = x.shape
    num_experts = router_weight.shape[0]
    hidden_size = x.shape[-1]

    # Flatten to (num_tokens, hidden_size). reshape() is used instead of
    # view() so that non-contiguous inputs (e.g. a transposed batch) work
    # rather than raising a RuntimeError.
    hidden_states = x.reshape(-1, hidden_size)

    # Router: compute logits, keep the top-k experts per token and normalise
    # only over the selected logits.
    logits = F.linear(hidden_states, router_weight, router_bias)
    expert_weights, router_indices = torch.topk(logits, top_k, dim=-1)
    routing_weights = F.softmax(expert_weights, dim=-1)

    # Accumulator for the weighted expert outputs.
    next_states = torch.zeros_like(hidden_states)

    # Build a boolean-like mask of which (expert, top_k slot, token) triples
    # are active. Pure index bookkeeping, so no gradients are needed.
    with torch.no_grad():
        expert_mask = F.one_hot(router_indices, num_classes=num_experts)
        # (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
        expert_mask = expert_mask.permute(2, 1, 0)
        # Experts that received at least one token.
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

    # Process only the experts that actually have tokens routed to them.
    for expert_idx in expert_hit:
        expert_idx = expert_idx[0]
        with torch.no_grad():
            top_k_idx, token_idx = torch.where(expert_mask[expert_idx])
        current_state = hidden_states[token_idx]

        # Up projection
        gate_up = (
            current_state @ gate_up_proj[expert_idx] + gate_up_proj_bias[expert_idx]
        )

        # Split the interleaved channels into gate and up
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]

        # Clamp: gate is bounded from above only, up on both sides
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)

        # SwiGLU-like activation
        glu = gate * torch.sigmoid(gate * alpha)
        gated_output = (up + 1) * glu

        # Down projection
        out = gated_output @ down_proj[expert_idx] + down_proj_bias[expert_idx]

        # Routing weight of this expert at the top_k slot where it was chosen
        weights_for_expert = routing_weights[token_idx, top_k_idx]
        weighted_output = out * weights_for_expert[:, None]

        next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

    return next_states.reshape(in_shape), routing_weights
class MegaBlocksMoeBenchmark(Benchmark):
    """Benchmark the MegaBlocks MoE MLP layer against a PyTorch reference.

    Two workloads are defined: ``base`` (8 tokens) and ``large`` (512 tokens).
    Each has a setup method that builds random weights and the kernel-backed
    model, a ``benchmark_*`` method that runs the model, and a ``verify_*``
    method that returns the output of ``moe_mlp_reference`` for the same
    inputs.

    NOTE(review): ``self.device`` and ``self.kernel`` are read but never
    assigned here; presumably the ``Benchmark`` base class provides them and
    pairs the setup/benchmark/verify methods by name - confirm against the
    ``kernels.benchmark`` implementation.
    """

    seed: int = 42

    def _build_moe_fixture(
        self, batch: int, seq: int, capacity_factor: "float | None" = None
    ) -> None:
        """Create weights, input and model shared by every workload.

        Args:
            batch: Size of the second dimension of the input tensor.
            seq: Size of the first dimension of the input tensor.
            capacity_factor: When given, an extra ``capacity_factor`` field
                with this value is added to the ``Experts`` tuple. When
                ``None`` the field is omitted entirely.
        """
        ne, hs, isz = 128, 1152, 3072
        device = self.device
        dtype = torch.float32

        # Router
        self.router_weight = torch.randn(ne, hs, device=device, dtype=dtype)
        torch.nn.init.kaiming_uniform_(self.router_weight)
        self.router_bias = torch.zeros(ne, device=device, dtype=dtype)

        # Expert weights
        self.gate_up_proj = (
            torch.randn(ne, hs, isz, device=device, dtype=dtype) * 0.02
        )
        self.gate_up_proj_bias = torch.zeros(ne, isz, device=device, dtype=dtype)
        self.down_proj = (
            torch.randn(ne, isz // 2, hs, device=device, dtype=dtype) * 0.02
        )
        self.down_proj_bias = torch.zeros(ne, hs, device=device, dtype=dtype)

        # Input
        self.x = torch.randn(seq, batch, hs, device=device, dtype=dtype) * 0.1

        # Setup the model; the router gets copies of the reference weights so
        # that the kernel and the reference route tokens identically.
        self.model = self.kernel.layers.MegaBlocksMoeMLP()
        self.model.router = torch.nn.Linear(hs, ne, device=device)
        self.model.router.weight.data = self.router_weight.clone()
        self.model.router.bias.data = self.router_bias.clone()

        expert_fields = {
            "gate_up_proj": torch.nn.Parameter(self.gate_up_proj.clone()),
            "gate_up_proj_bias": torch.nn.Parameter(self.gate_up_proj_bias.clone()),
            "down_proj": torch.nn.Parameter(self.down_proj.clone()),
            "down_proj_bias": torch.nn.Parameter(self.down_proj_bias.clone()),
            "hidden_size": hs,
            "num_experts": ne,
        }
        # Only expose capacity_factor when a workload asks for it, so that the
        # model sees exactly the same attribute set as before.
        if capacity_factor is not None:
            expert_fields["capacity_factor"] = capacity_factor
        Experts = namedtuple("Experts", list(expert_fields))
        self.model.experts = Experts(**expert_fields)

        self.out = torch.empty(seq, batch, hs, device=device, dtype=dtype)

    def _reference_output(self) -> torch.Tensor:
        """Return the reference MoE output for the current inputs and weights."""
        ref_out, _ = moe_mlp_reference(
            self.x,
            self.router_weight,
            self.router_bias,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            top_k=4,
        )
        return ref_out

    def setup(self):
        """Prepare the base workload (batch=8, seq=1)."""
        # Config matching readme_example.py
        self._build_moe_fixture(batch=8, seq=1)

    def benchmark_base(self):
        """Run the model on the base workload and keep its outputs."""
        self.out, self.expert_weights = self.model(self.x)

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the base workload."""
        return self._reference_output()

    def setup_large(self):
        """Prepare the larger workload (batch=32, seq=16)."""
        # Larger config with more tokens; higher capacity to avoid token
        # dropping.
        self._build_moe_fixture(batch=32, seq=16, capacity_factor=4.0)

    def benchmark_large(self):
        """Run the model on the large workload and keep its outputs."""
        self.out, self.expert_weights = self.model(self.x)

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large workload."""
        return self._reference_output()