# sproto / modeling_sproto.py
# Uploaded by RamezCh using huggingface_hub (commit 127b6bf, verified).
import contextlib
from dataclasses import dataclass

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from sproto.model.multi_proto import MultiProtoModule

from .configuration_sproto import SprotoConfig
@dataclass
class SprotoOutput(ModelOutput):
    """Output container returned by ``SprotoModel.forward``.

    Attributes:
        logits: first element of the tuple returned by ``MultiProtoModule``.
        max_indices: second element of that tuple.
        metadata: third element of that tuple (annotated as ``dict``).

    String keys (``out["logits"]``, ``"logits" in out``) are resolved through
    attribute access, so a field is reachable by name even when its value is
    ``None``. Any other key type (integer / slice indexing such as ``out[0]``)
    is delegated to ``ModelOutput``'s own implementation.
    """

    logits: torch.Tensor = None
    max_indices: torch.Tensor = None
    metadata: dict = None

    def __contains__(self, key):
        # hasattr() only accepts str; delegate other key types to the parent
        # mapping instead of raising TypeError.
        if isinstance(key, str):
            return hasattr(self, key)
        return super().__contains__(key)

    def __getitem__(self, key):
        # getattr() raises TypeError for non-str keys, which would break
        # positional access like ``out[0]`` / ``out[:2]``; let ModelOutput
        # handle those.
        if isinstance(key, str):
            return getattr(self, key)
        return super().__getitem__(key)
class SprotoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``MultiProtoModule``.

    Every hyper-parameter is read from a ``SprotoConfig``. The wrapped module is
    stored as ``self.module``, so its weights live under the ``module.`` prefix
    of this model's state dict.
    """

    config_class = SprotoConfig
    base_model_prefix = "sproto"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        module_kwargs = dict(
            pretrained_model=config.pretrained_model,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_attention=config.use_attention,
            use_global_attention=config.use_global_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            num_prototypes_per_class=config.num_prototypes_per_class,
            loss=config.loss,
            use_prototype_loss=config.use_prototype_loss,
            use_sigmoid=config.use_sigmoid,
            seed=config.seed,
            # use_cuda=False: HF handles device placement via .to(device);
            # manual .cuda() calls in the base model are bypassed here.
            use_cuda=False,
        )

        # HF's from_pretrained (on newer PyTorch / Transformers) can set a
        # meta-device context during weight loading. MultiProtoModule.__init__
        # internally calls AutoModel.from_pretrained() for the BERT backbone,
        # which conflicts with that context. We escape it by forcing CPU
        # initialization whenever torch.device can be used as a context manager
        # (torch >= 2.0). Older torch lacks that feature and does not have the
        # conflict, so we construct normally there.
        # Feature detection is used instead of parsing torch.__version__, which
        # mis-handles pre-release strings such as "2.0.0a0+git...".
        if hasattr(torch.device, "__enter__"):
            init_ctx = torch.device("cpu")
        else:
            init_ctx = contextlib.nullcontext()
        with init_ctx:
            self.module = MultiProtoModule(**module_kwargs)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids=None,
        targets=None,
        tokens=None,
        sample_ids=None,
    ):
        """Run ``MultiProtoModule`` on one batch and wrap its outputs.

        Args:
            input_ids: token-id tensor; ``shape[0]`` is used as the batch size
                and ``shape[1]`` as the sequence length (presumably
                ``(batch, seq_len)``).
            attention_mask: passed to the module under the
                ``"attention_masks"`` key.
            token_type_ids: defaults to ``torch.zeros_like(input_ids)``.
            targets: defaults to a zero tensor of shape
                ``(batch, config.num_classes)`` on ``input_ids.device``.
            tokens: list of token-string lists, one per sample. Required when
                ``config.use_attention`` is true; otherwise a ``"[PAD]"``
                placeholder is generated.
            sample_ids: defaults to ``["sample_0", "sample_1", ...]``.

        Returns:
            SprotoOutput holding the ``(logits, max_indices, metadata)`` tuple
            returned by the wrapped module.

        Raises:
            ValueError: if ``tokens`` is None while ``config.use_attention``
                is true.
        """
        # tokens MUST be provided when use_attention=True.
        # attention_mask_from_tokens() relies on real token strings to zero out
        # clinical section headers ([CLS], [SEP], "chief complaint :", etc.)
        # before computing token-prototype attention. A fake default would
        # silently produce wrong logits.
        if tokens is None and self.config.use_attention:
            raise ValueError(
                "tokens (list-of-lists of token strings per sample) must be provided "
                "when use_attention=True. Obtain them via:\n"
                " tokenizer.convert_ids_to_tokens(input_ids[i])\n"
                "for each sample i in the batch, or pass the full batch at once with:\n"
                " [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]"
            )

        batch_size = input_ids.shape[0]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if targets is None:
            # Keep the placeholder on the inputs' device so CPU and accelerator
            # tensors are never mixed inside the module.
            targets = torch.zeros(
                batch_size, self.config.num_classes, device=input_ids.device
            )
        if tokens is None:
            # One independent list per sample; `[row] * batch_size` would make
            # every entry alias the same list object.
            tokens = [["[PAD]"] * input_ids.shape[1] for _ in range(batch_size)]
        if sample_ids is None:
            sample_ids = [f"sample_{i}" for i in range(batch_size)]

        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
            "targets": targets,
            "tokens": tokens,
            "sample_ids": sample_ids,
        }
        logits, max_indices, metadata = self.module(batch)
        return SprotoOutput(
            logits=logits,
            max_indices=max_indices,
            metadata=metadata,
        )

    def get_embeddings(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``last_hidden_state`` of the backbone ``self.module.bert``.

        ``token_type_ids`` defaults to ``torch.zeros_like(input_ids)``. The
        result is presumably ``(batch, seq_len, hidden)`` — confirm against the
        backbone in use.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        bert_output = self.module.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return bert_output.last_hidden_state

    def _init_weights(self, module):
        # Intentionally a no-op. The backbone weights come from the
        # AutoModel.from_pretrained() call inside MultiProtoModule.__init__;
        # presumably this keeps HF's generic initializer from touching them.
        pass