# sproto / configuration_sproto.py
# RamezCh's picture
# Upload folder using huggingface_hub
# 127b6bf verified
from transformers import PretrainedConfig
class SprotoConfig(PretrainedConfig):
    """Configuration holder for the ``sproto`` model type.

    Every constructor argument is stored on the instance under an attribute
    of the same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    The only argument that is transformed before being stored is
    ``num_prototypes_per_class``: the checkpoint serialises that buffer as a
    list (one float per class), while MultiProtoModule expects a scalar int
    for the uniform case. A list/tuple is therefore collapsed to a single
    ``int``; any other value is stored unchanged.

    Raises:
        ValueError: If ``num_prototypes_per_class`` is a list/tuple that is
            empty or whose entries do not all collapse to the same integer.
            (Earlier revisions used ``assert`` here, which raised
            ``AssertionError`` and was skipped entirely under ``python -O``.)
    """

    model_type = "sproto"

    def __init__(
        self,
        pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        num_classes=55,
        label_order_path=None,
        use_attention=True,
        use_global_attention=False,
        dot_product=False,
        normalize=None,
        final_layer=False,
        reduce_hidden_size=None,
        num_prototypes_per_class=1,
        loss="BCE",
        use_prototype_loss=False,
        use_sigmoid=False,
        seed=7,
        vocab_size=28996,
        hidden_size=768,
        max_position_embeddings=512,
        attention_probs_dropout_prob=0.1,
        hidden_dropout_prob=0.1,
        **kwargs,
    ):
        # Collapse a per-class list into the scalar that MultiProtoModule
        # expects. Validation uses explicit raises rather than ``assert`` so
        # it still runs when Python is started with -O; otherwise a
        # non-uniform list would silently collapse to an arbitrary element.
        if isinstance(num_prototypes_per_class, (list, tuple)):
            unique_vals = {int(v) for v in num_prototypes_per_class}
            if not unique_vals:
                raise ValueError(
                    "num_prototypes_per_class must not be an empty sequence."
                )
            if len(unique_vals) != 1:
                raise ValueError(
                    "Non-uniform num_prototypes_per_class cannot be represented as a scalar "
                    "config value. Got: {}".format(unique_vals)
                )
            num_prototypes_per_class = unique_vals.pop()

        self.pretrained_model = pretrained_model
        self.num_classes = num_classes
        self.label_order_path = label_order_path
        self.use_attention = use_attention
        self.use_global_attention = use_global_attention
        self.dot_product = dot_product
        self.normalize = normalize
        self.final_layer = final_layer
        self.reduce_hidden_size = reduce_hidden_size
        self.num_prototypes_per_class = num_prototypes_per_class
        self.loss = loss
        self.use_prototype_loss = use_prototype_loss
        self.use_sigmoid = use_sigmoid
        self.seed = seed
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob

        # The base-class initialiser runs last, after all model-specific
        # attributes are set, matching the original ordering.
        super().__init__(**kwargs)