# Source: Hugging Face "Kernels" repository page.
# Uploaded by danieldk (HF Staff).
# Commit 64a7ea9 (verified): Build uploaded using `kernels`.
# SPDX-License-Identifier: Apache-2.0
# MegaBlocks CPU Fused MoE Implementation
#
# This is a pure Python/PyTorch implementation for CPU.
# For better performance, consider using the C++ kernel implementation.
#
import torch
import torch.nn.functional as F
def swigluoai_activation(
    gate: torch.Tensor,
    up: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> torch.Tensor:
    """
    Apply the SwigluOAI gated activation (the variant used by GptOss models).

    Computation:
        gate = clamp(gate, max=limit)        # upper bound only
        up = clamp(up, -limit, limit)        # bounded on both sides
        glu = gate * sigmoid(gate * alpha)
        output = (up + 1) * glu

    Args:
        gate: Output of the gate projection.
        up: Output of the up projection; must broadcast against ``gate``.
        alpha: Multiplier applied to ``gate`` inside the sigmoid (default: 1.702).
        limit: Clamp bound for both inputs (default: 7.0).

    Returns:
        The activated tensor.
    """
    # The gate is only capped from above; large negative values pass through.
    bounded_gate = torch.clamp(gate, max=limit)
    bounded_up = torch.clamp(up, min=-limit, max=limit)
    swish = bounded_gate * torch.sigmoid(bounded_gate * alpha)
    return (bounded_up + 1) * swish
def silu_and_mul_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    Gate ``up`` with the SiLU (Swish) of ``gate``.

    Computation:
        output = silu(gate) * up

    Args:
        gate: Tensor passed through SiLU.
        up: Tensor multiplied element-wise with the activated gate.

    Returns:
        The element-wise product ``silu(gate) * up``.
    """
    activated_gate = F.silu(gate)
    return activated_gate * up
def route_tokens_cpu(
x: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor | None,
moe_top_k: int,
moe_num_experts: int,
moe_normalize_expert_weights: int | None = None,
) -> tuple:
"""
Route tokens to experts and compute expert weights and indices (CPU version).
Args:
x: Input tensor [batch, seq, hidden] or [tokens, hidden]
router_weight: Router weight [num_experts, hidden]
router_bias: Router bias [num_experts] or None
moe_top_k: Number of experts per token
moe_num_experts: Total number of experts
moe_normalize_expert_weights: Normalization order or None
Returns:
Tuple of (logits, expert_weights, expert_indices)
"""
x_flat = x.view(-1, x.shape[-1])
logits = F.linear(x_flat, router_weight, router_bias)
if moe_top_k == 1:
expert_weights, expert_indices = logits.max(dim=-1, keepdim=True)
else:
expert_weights, expert_indices = torch.topk(logits, moe_top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)
if moe_normalize_expert_weights is not None:
expert_weights = expert_weights / torch.norm(
expert_weights,
p=moe_normalize_expert_weights,
dim=-1,
keepdim=True,
)
return logits, expert_weights, expert_indices
def cpu_fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    activation: str = "silu",
    alpha: float = 1.702,
    limit: float = 7.0,
    is_interleaved: bool = True,
) -> torch.Tensor:
    """
    CPU Mixture-of-Experts MLP using plain PyTorch operations.

    Experts are processed one after another. For each expert that received at
    least one token, the tokens routed to it are gathered, pushed through
    ``w1 -> activation -> w2`` as one batched matmul pair, scaled by their
    routing weight and scatter-added into the output. Experts that received no
    token are skipped entirely.

    Args:
        hidden_states: [num_tokens, hidden_size]
        w1: [num_experts, hidden_size, 2*inter_size] - gate_up_proj weights
        w2: [num_experts, inter_size, hidden_size] - down_proj weights
        topk_weights: [num_tokens, topk] - routing weights
        topk_ids: [num_tokens, topk] - expert indices; ids outside
            ``[0, num_experts)`` are ignored
        w1_bias: [num_experts, 2*inter_size] or None
        w2_bias: [num_experts, hidden_size] or None
        activation: "swigluoai" selects the SwigluOAI activation; any other
            value falls back to SiLU-and-multiply
        alpha: swigluoai alpha parameter
        limit: swigluoai limit parameter
        is_interleaved: whether gate_up is interleaved [g0,u0,g1,u1,...]
            (True for GptOss) rather than [gate_all, up_all]

    Returns:
        output: [num_tokens, hidden_size], same dtype as ``hidden_states``

    Raises:
        ValueError: if ``hidden_states`` is not 2-D.
    """
    if hidden_states.dim() != 2:
        raise ValueError(
            "hidden_states must be 2-D [num_tokens, hidden_size], "
            f"got shape {tuple(hidden_states.shape)}"
        )

    num_experts = w1.shape[0]
    inter_size = w2.shape[1]
    use_swigluoai = activation == "swigluoai"

    output = torch.zeros_like(hidden_states)

    # Visit only the experts that actually appear in topk_ids instead of
    # building a mask for every expert. torch.unique returns the ids in
    # ascending order, so accumulation order matches a 0..num_experts-1 scan.
    routed_experts = torch.unique(topk_ids.to(torch.int64)).tolist()

    for expert_idx in routed_experts:
        if not 0 <= expert_idx < num_experts:
            continue

        # (token, slot) pairs routed to this expert.
        token_indices, topk_positions = torch.where(topk_ids == expert_idx)
        if token_indices.numel() == 0:
            continue

        # Gather the routed tokens: [num_selected, hidden_size]
        current_hidden = hidden_states[token_indices]

        # First projection: [num_selected, 2*inter_size]
        gate_up = current_hidden @ w1[expert_idx]
        if w1_bias is not None:
            gate_up = gate_up + w1_bias[expert_idx]

        # Split into gate and up halves, each [num_selected, inter_size]
        if is_interleaved:
            # GptOss layout: [g0, u0, g1, u1, ...]
            gate = gate_up[..., ::2]
            up = gate_up[..., 1::2]
        else:
            # Standard layout: [gate_all, up_all]
            gate = gate_up[..., :inter_size]
            up = gate_up[..., inter_size:]

        if use_swigluoai:
            activated = swigluoai_activation(gate, up, alpha, limit)
        else:  # silu
            activated = silu_and_mul_activation(gate, up)

        # Second projection: [num_selected, hidden_size]
        expert_out = activated @ w2[expert_idx]
        if w2_bias is not None:
            expert_out = expert_out + w2_bias[expert_idx]

        # Scale each row by its routing weight: [num_selected, 1]
        weights = topk_weights[token_indices, topk_positions].unsqueeze(-1)
        weighted_out = expert_out * weights

        # The product may have been promoted (e.g. float32 routing weights
        # with lower-precision activations); cast back before accumulating.
        output.index_add_(0, token_indices, weighted_out.to(output.dtype))

    return output
class MegaBlocksMoeMLP(torch.nn.Module):
    """
    CPU MoE MLP module that can be used as a drop-in replacement for
    the transformers GptOssMLP when using @use_kernel_forward_from_hub.

    The class defines no parameters of its own: ``forward`` reads
    ``self.router`` and ``self.experts``, which must already exist on the
    instance (presumably supplied by the module whose forward is being
    replaced - confirm against the hub integration).
    """

    # NOTE(review): looks like a flag consumed by the `kernels` hub
    # integration; nothing in this file reads it.
    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the MoE layer.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size] or
                [tokens, hidden_size]. It does not need to be contiguous.

        Returns:
            Tuple of (output, expert_weights) where:
            - output: Tensor of same shape as input
            - expert_weights: Expert weights for each token [tokens, top_k]

        Raises:
            AttributeError: if ``self.experts`` exposes neither
                ``gate_up_proj``/``w1`` nor ``down_proj``/``w2``.
        """
        # Get MoE parameters from the wrapped modules
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        moe_normalize_expert_weights = getattr(
            self.experts, "normalize_expert_weights", None
        )

        # Detect activation type: experts carrying both `alpha` and `limit`
        # are treated as SwigluOAI experts.
        if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
            activation = "swigluoai"
            alpha = self.experts.alpha
            limit = self.experts.limit
        else:
            activation = getattr(self.experts, "activation", "silu")
            alpha = 1.702
            limit = 7.0

        # Get the fused gate/up projection weights
        if hasattr(self.experts, "gate_up_proj"):
            w1 = self.experts.gate_up_proj
            is_interleaved = True  # GptOss uses interleaved layout
        elif hasattr(self.experts, "w1"):
            w1 = self.experts.w1
            w3 = getattr(self.experts, "w3", None)
            if w3 is not None:
                # Separate gate (w1) and up (w3) weights: fuse them into the
                # non-interleaved [gate_all, up_all] layout.
                w1 = torch.cat([w1, w3], dim=-1)
            is_interleaved = False
        else:
            raise AttributeError("experts module must have 'gate_up_proj' or 'w1' attribute")

        # Get the down projection weights
        if hasattr(self.experts, "down_proj"):
            w2 = self.experts.down_proj
        elif hasattr(self.experts, "w2"):
            w2 = self.experts.w2
        else:
            raise AttributeError("experts module must have 'down_proj' or 'w2' attribute")

        # Get optional bias tensors
        w1_bias = getattr(self.experts, "gate_up_proj_bias", None)
        w2_bias = getattr(self.experts, "down_proj_bias", None)

        # Remember the caller's shape and flatten to [tokens, hidden] once.
        # reshape (not view) so non-contiguous inputs are accepted.
        in_shape = x.size()
        x_flat = x.reshape(-1, x.shape[-1])

        # Route tokens to experts; the raw logits are not needed here.
        _logits, expert_weights, expert_indices = route_tokens_cpu(
            x_flat,
            self.router.weight,
            getattr(self.router, "bias", None),
            moe_top_k,
            moe_num_experts,
            moe_normalize_expert_weights,
        )

        # Call CPU fused MoE
        output = cpu_fused_moe(
            hidden_states=x_flat,
            w1=w1,
            w2=w2,
            topk_weights=expert_weights,
            topk_ids=expert_indices,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            activation=activation,
            alpha=alpha,
            limit=limit,
            is_interleaved=is_interleaved,
        )

        # Restore original shape
        output = output.reshape(in_shape)
        return output, expert_weights
# Public API: the MoE module first, then the functional building blocks.
__all__ = [
    "MegaBlocksMoeMLP", "cpu_fused_moe", "route_tokens_cpu",
    "swigluoai_activation", "silu_and_mul_activation",
]