from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class FM4BioConfig(PretrainedConfig):
    """Configuration container for the ``fm4bio`` model type.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name. ``pad_token_id`` and any extra ``**kwargs``
    are forwarded to ``PretrainedConfig.__init__``. How the individual values
    are consumed is decided by the modeling code, which is not part of this
    file; the descriptions below are limited to what this class itself does.

    Args:
        vocab_size: Size of the input vocabulary.
        hidden_size, num_hidden_layers, num_attention_heads,
        intermediate_size, hidden_act, hidden_dropout_prob,
        attention_probs_dropout_prob, max_position_embeddings,
        type_vocab_size, initializer_range, layer_norm_eps,
        add_linear_bias, position_embedding_type, use_cache,
        rotary_percent, seq_len_interpolation_factor:
            Architecture / training hyper-parameters, stored as-is.
            NOTE(review): semantics presumably follow the usual
            transformer conventions suggested by the names -- confirm
            against the modeling code.
        pad_token_id: Forwarded to the base-class constructor.
        normalization_type: Must be ``"RMSNorm"`` or ``"LayerNorm"``.
        moe, num_experts, experts_per_token: Stored as-is.
            NOTE(review): presumably mixture-of-experts settings -- confirm.
        use_lm_head: Stored as-is.
        tie_word_embeddings: Stored on the instance after the base-class
            constructor has run, so this value is the one that sticks.
        output_vocab_size: When set, the output vocab size is different
            from the input vocab size.
        gradient_checkpointing: Gradient checkpointing flag (for memory
            saving).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``normalization_type`` is not one of the supported
            values.
    """

    model_type = "fm4bio"

    def __init__(
        self,
        vocab_size=128,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="swiglu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=2048,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-05,
        pad_token_id=0,
        add_linear_bias=True,
        position_embedding_type="rope",
        normalization_type="RMSNorm",
        use_cache=True,
        rotary_percent=1.0,
        seq_len_interpolation_factor=None,
        moe=False,
        num_experts=0,
        experts_per_token=0,
        use_lm_head=True,
        tie_word_embeddings=True,
        # When set, the output vocab size is different from the input vocab size.
        output_vocab_size: Optional[int] = None,
        # Gradient checkpointing for memory saving.
        gradient_checkpointing=False,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.add_linear_bias = add_linear_bias

        # Explicit raise instead of `assert`: assertions are removed when
        # Python runs with -O, which would let an invalid value through.
        if normalization_type not in ("RMSNorm", "LayerNorm"):
            raise ValueError(
                "normalization_type must be 'RMSNorm' or 'LayerNorm'"
            )
        self.normalization_type = normalization_type

        self.rotary_percent = rotary_percent
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        self.moe = moe
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.use_lm_head = use_lm_head
        # Assigned after super().__init__ so the explicit argument overrides
        # whatever the base class set for this attribute.
        self.tie_word_embeddings = tie_word_embeddings
        self.output_vocab_size = output_vocab_size
        self.gradient_checkpointing = gradient_checkpointing