from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging
import json
from configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return the number of token labels for every task.

    Args:
        label_map: Mapping from task name to the collection of that task's labels.

    Returns:
        dict: ``{task: len(labels)}`` for every task in ``label_map``.
    """
    return {task: len(labels) for task, labels in label_map.items()}


class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    """Multi-task token classifier with optional temporal (year) conditioning.

    Pipeline: base encoder (built from ``config.pretrained_config``) -> dropout
    -> two extra ``nn.TransformerEncoder`` layers -> optional
    :class:`TemporalFusion` -> one linear head per task in ``config.label_map``.
    With the ``"early-cross-attention"`` strategy the year signal is instead
    injected into the input embeddings before the base encoder runs.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy

        # Load the base model from its config instead of from_pretrained to avoid conflicts.
        base_config = config.pretrained_config
        if isinstance(base_config, dict):
            # NOTE(review): a dict config is always interpreted as a BertConfig,
            # regardless of which architecture it describes. Changing this would
            # alter the module layout expected by existing checkpoints.
            from transformers import BertConfig

            base_config = BertConfig(**base_config)
        self.model = AutoModel.from_config(base_config)
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years

        classifier_dropout = (
            getattr(config, "classifier_dropout", 0.1) or config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        self.temporal_fusion = TemporalFusion(
            config.hidden_size,
            strategy=self.temporal_fusion_strategy,
            num_years=num_years,
        )

        # Two additional encoder layers on top of the base model. They use the
        # default sequence-first layout (batch_first=False), hence the
        # transposes in forward().
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        self.token_classifiers = nn.ModuleDict({
            task: nn.Linear(config.hidden_size, num_labels)
            for task, num_labels in self.num_token_labels_dict.items()
        })

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Run the encoder, optional temporal fusion, and every task head.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask:
                Forwarded to the base model. ``token_type_ids`` and ``head_mask``
                are dropped when ``config.name_or_path`` mentions llama/deberta.
            labels: Not used by this method; per-task labels go in ``token_labels``.
            inputs_embeds: Optional precomputed input embeddings. Used when
                ``input_ids`` is None, or as the input to early cross-attention.
            token_labels: Optional dict mapping task name to a label tensor.
                Every task present contributes a cross-entropy term to the loss.
            date_indices, decade_index, century_index: Not used by this method.
            year_index: Year indices consumed by the temporal fusion module.
                Required by every strategy except ``"baseline"``.
            output_attentions, output_hidden_states: Forwarded to the base model.
            return_dict: Return a ``TokenClassifierOutput`` (default from config)
                or a plain tuple.

        Returns:
            ``TokenClassifierOutput`` whose ``logits`` is a dict of per-task
            logits and whose ``loss`` is the sum of per-task losses, or None when
            no task labels were given. When ``hidden_states`` are produced, the
            final fused token representation is appended as the last entry.
            With ``return_dict=False``: ``([loss,] task_logits, [hidden_states,]
            [attentions])``, where bracketed items appear only when present.

        Raises:
            ValueError: if ``year_index`` is missing for the early cross-attention
                strategy (other strategies raise from ``TemporalFusion``).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.temporal_fusion_strategy == "early-cross-attention":
            if year_index is None:
                raise ValueError(
                    "year_index is required for the 'early-cross-attention' fusion strategy"
                )
            if inputs_embeds is None:
                # NOTE(review): for BERT-style models `embeddings(...)` already
                # adds position/token-type embeddings and LayerNorm, and the base
                # model re-applies them to `inputs_embeds`. Kept as-is so existing
                # checkpoints behave the same; confirm whether
                # `get_input_embeddings()` was intended.
                inputs_embeds = self.model.embeddings(input_ids)
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)
            inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)
            model_input_ids, model_inputs_embeds = None, inputs_embeds
        elif input_ids is not None:
            model_input_ids, model_inputs_embeds = input_ids, None
        else:
            # Fall back to caller-supplied embeddings when no ids are given.
            model_input_ids, model_inputs_embeds = None, inputs_embeds

        bert_kwargs = {
            "input_ids": model_input_ids,
            "inputs_embeds": model_inputs_embeds,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            # Always get a ModelOutput internally so fields can be read by name;
            # the caller's return_dict only controls the final output format.
            "return_dict": True,
        }
        if any(keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"]):
            bert_kwargs.pop("token_type_ids", None)
            bert_kwargs.pop("head_mask", None)

        outputs = self.model(**bert_kwargs)
        token_output = self.dropout(outputs[0])  # (B, T, H)

        # NOTE(review): attention_mask is not passed as src_key_padding_mask, so
        # these layers also attend to padded positions. Adding it would change
        # the outputs of existing checkpoints — confirm before changing.
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(0, 1)

        # Late fusion strategies act on the encoder output; early fusion already
        # happened above and "baseline" adds no temporal signal.
        if self.temporal_fusion_strategy not in ("baseline", "early-cross-attention"):
            token_output = self.temporal_fusion(token_output, year_index)

        hidden_states = None
        if outputs.hidden_states is not None:
            # Append the final fused state after the base model's hidden states.
            hidden_states = tuple(outputs.hidden_states) + (token_output,)
        attentions = outputs.attentions

        loss_fct = CrossEntropyLoss()
        task_logits = {}
        total_loss = None
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss = loss if total_loss is None else total_loss + loss

        if not return_dict:
            output = (task_logits,) + tuple(
                extra for extra in (hidden_states, attentions) if extra is not None
            )
            return output if total_loss is None else (total_loss,) + output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )


class TemporalFusion(nn.Module):
    """Fuse a learned year embedding into per-token hidden states.

    ``year_index`` values index ``year_emb`` (size ``num_years``). The
    ``"multiscale"`` strategy recovers the calendar year as
    ``year_index + min_year``, so indices are offsets from ``min_year``.

    Strategies:
        baseline, add, concat, film, adapter, relative, multiscale,
        late-cross-attention (applied in ``forward``) and early-cross-attention
        (its modules are used directly by the caller, not via ``forward``).
    """

    def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
        super().__init__()
        self.strategy = strategy
        self.hidden_size = hidden_size
        self.min_year = min_year
        self.max_year = min_year + num_years - 1
        self.year_emb = nn.Embedding(num_years, hidden_size)

        if strategy == "concat":
            self.concat_proj = nn.Linear(hidden_size * 2, hidden_size)
        elif strategy == "film":
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "adapter":
            self.adapter = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        elif strategy == "relative":
            self.relative_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.LayerNorm(hidden_size),
            )
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "multiscale":
            self.decade_emb = nn.Embedding(1000, hidden_size)
            self.century_emb = nn.Embedding(100, hidden_size)
        elif strategy in ["early-cross-attention", "late-cross-attention"]:
            self.year_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
            )
            self.cross_attn = TemporalCrossAttention(hidden_size)

    def compute_time_embedding(self, year_index):
        """Return the time embedding for ``year_index``.

        Cross-attention strategies pass the year embedding through
        ``year_encoder``. ``"multiscale"`` sums year, decade, and century
        embeddings. All other strategies return the plain year embedding.
        """
        if self.strategy in ["early-cross-attention", "late-cross-attention"]:
            return self.year_encoder(self.year_emb(year_index))
        elif self.strategy == "multiscale":
            year_index = year_index.long()
            year = year_index + self.min_year
            decade = (year // 10).long()
            century = (year // 100).long()
            return (
                self.year_emb(year_index)
                + self.decade_emb(decade)
                + self.century_emb(century)
            )
        else:
            return self.year_emb(year_index)

    def forward(self, token_output, year_index):
        """Apply the configured fusion to ``token_output`` of shape (B, T, H).

        Raises:
            ValueError: if ``year_index`` is None for a non-baseline strategy,
                or the strategy is not handled here.
        """
        _, num_tokens, _ = token_output.size()  # also enforces a 3-D input
        if self.strategy == "baseline":
            return token_output
        if year_index is None:
            raise ValueError(
                f"year_index is required for the '{self.strategy}' fusion strategy"
            )

        year_emb = self.compute_time_embedding(year_index)

        if self.strategy == "concat":
            # expand() views the embedding across tokens without copying.
            expanded_year = year_emb.unsqueeze(1).expand(-1, num_tokens, -1)
            fused = torch.cat([token_output, expanded_year], dim=-1)
            return self.concat_proj(fused)
        elif self.strategy == "film":
            gamma = self.film_gamma(year_emb).unsqueeze(1)
            beta = self.film_beta(year_emb).unsqueeze(1)
            return gamma * token_output + beta
        elif self.strategy == "adapter":
            return token_output + self.adapter(year_emb).unsqueeze(1)
        elif self.strategy in ("add", "multiscale"):
            # Broadcasting over the token axis adds the same year vector to every token.
            return token_output + year_emb.unsqueeze(1)
        elif self.strategy == "relative":
            encoded = self.relative_encoder(year_emb)
            gamma = self.film_gamma(encoded).unsqueeze(1)
            beta = self.film_beta(encoded).unsqueeze(1)
            return gamma * token_output + beta
        elif self.strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_emb)
        else:
            raise ValueError(f"Unknown fusion strategy: {self.strategy}")


class TemporalCrossAttention(nn.Module):
    """Residual cross-attention from tokens (queries) to a single time vector.

    ``nn.MultiheadAttention`` requires ``hidden_size`` to be divisible by
    ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )

    def forward(self, token_output, time_embedding):
        # token_output: (B, T, H), time_embedding: (B, H)
        time_as_seq = time_embedding.unsqueeze(1)  # (B, 1, H): one key/value per example
        attn_output, _ = self.attn(token_output, time_as_seq, time_as_seq)
        return token_output + attn_output


# Register the model
ExtendedMultitaskTimeModelForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")