"""A GPT model with the SD-PReLU (DTSG-PReLU) MLP activation, for HF Transformers. The backbone is GPT-2's (attention, layernorms, embeddings, tied head), but the MLP activation is no longer GELU, so these classes drop the "GPT2" name. This mirrors the llm.kittens `-af sd-prelu` activation exactly: per layer there are two learnable scalars (theta_a, theta_b), materialized through a bounded reparameterization and applied elementwise between the up- and down-projections: a = alpha_max * sigmoid(theta_a) # in [0, alpha_max) b = beta_min + softplus(theta_b) # in (beta_min, inf) phi(x) = x * (a + (1 - a) * sigmoid(b * x)) To match the CUDA kernel bit-for-bit in spirit, the materialization and the gate sigmoid clamp their argument to [-20, 20] and the activation is computed in float32 (inputs/outputs stay in the surrounding model dtype, typically bfloat16). Load with: from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True) tok = AutoTokenizer.from_pretrained(path) """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from transformers import GPT2Config, GPT2LMHeadModel from transformers.models.gpt2.modeling_gpt2 import GPT2MLP def sdprelu( x: torch.Tensor, theta_a: torch.Tensor, theta_b: torch.Tensor, alpha_max: float, beta_min: float, ) -> torch.Tensor: """SD-PReLU activation: x * (a + (1 - a) * sigmoid(b * x)), computed in fp32.""" orig_dtype = x.dtype x = x.float() a = alpha_max * torch.sigmoid(theta_a.float().clamp(-20.0, 20.0)) b = beta_min + F.softplus(theta_b.float()) s = torch.sigmoid((b * x).clamp(-20.0, 20.0)) out = x * (a + (1.0 - a) * s) return out.to(orig_dtype) class GPTSDPReLUConfig(GPT2Config): model_type = "gpt-sdprelu" def __init__( self, sdprelu_alpha_max: float = 0.30, sdprelu_beta_min: float = 0.50, **kwargs, ): self.sdprelu_alpha_max = sdprelu_alpha_max self.sdprelu_beta_min = sdprelu_beta_min super().__init__(**kwargs) 
class GPTSDPReLUMLP(nn.Module):
    """GPT-2 MLP projections with the activation replaced by per-layer SD-PReLU.

    The ``c_fc`` / ``c_proj`` / ``dropout`` submodules are taken over from an
    existing ``GPT2MLP`` (the same module objects, not copies), so the weight
    names remain ``mlp.c_fc.*`` / ``mlp.c_proj.*``. Two extra parameters,
    ``theta_a`` and ``theta_b`` (each of shape ``(1,)``, initialised to zero),
    are added per layer.
    """

    def __init__(self, base_mlp: GPT2MLP, config: GPTSDPReLUConfig):
        super().__init__()

        # Fixed hyperparameters of the reparameterization, read from the config.
        self.alpha_max = float(config.sdprelu_alpha_max)
        self.beta_min = float(config.sdprelu_beta_min)

        # Reuse the wrapped MLP's submodules under their original names.
        self.c_fc = base_mlp.c_fc
        self.c_proj = base_mlp.c_proj
        self.dropout = base_mlp.dropout

        # Learnable raw parameters consumed by sdprelu().
        self.theta_a = nn.Parameter(torch.zeros(1))
        self.theta_b = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run c_fc -> SD-PReLU -> c_proj -> dropout on ``hidden_states``."""
        projected_up = self.c_fc(hidden_states)
        activated = sdprelu(
            projected_up,
            self.theta_a,
            self.theta_b,
            self.alpha_max,
            self.beta_min,
        )
        return self.dropout(self.c_proj(activated))


class GPTSDPReLULMHeadModel(GPT2LMHeadModel):
    """``GPT2LMHeadModel`` whose transformer blocks use ``GPTSDPReLUMLP``."""

    config_class = GPTSDPReLUConfig

    def __init__(self, config: GPTSDPReLUConfig):
        # Build the stock GPT-2 model first, then wrap every block's MLP so
        # its projections are kept and only the activation path changes.
        super().__init__(config)
        for layer in self.transformer.h:
            layer.mlp = GPTSDPReLUMLP(layer.mlp, config)


# Register both classes with the Transformers Auto* machinery.
GPTSDPReLUConfig.register_for_auto_class()
GPTSDPReLULMHeadModel.register_for_auto_class("AutoModelForCausalLM")