Token Classification
Transformers
PyTorch
Safetensors
English
French
German
stacked_bert
v1.0.0
custom_code
Instructions to use impresso-project/ner-stacked-bert-multilingual-v1.1.0 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use impresso-project/ner-stacked-bert-multilingual-v1.1.0 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="impresso-project/ner-stacked-bert-multilingual-v1.1.0", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("impresso-project/ner-stacked-bert-multilingual-v1.1.0", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers.modeling_outputs import TokenClassifierOutput | |
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig | |
| from torch.nn import CrossEntropyLoss | |
| from typing import Optional, Tuple, Union | |
| import logging, json, os | |
| from .configuration_stacked import ImpressoConfig | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_info(label_map):
    """Count the labels defined for each task.

    Args:
        label_map: mapping of task name -> collection of that task's labels.

    Returns:
        A new dict ``{task: number_of_labels}`` in the same key order as
        ``label_map``.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    """Multitask token classifier with optional temporal (year) fusion.

    Processing order: pretrained backbone (``AutoModel``) -> dropout -> two
    extra ``nn.TransformerEncoder`` layers -> optional ``TemporalFusion`` ->
    one linear head per task listed in ``config.label_map``.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        """Build the backbone, fusion module, extra encoder layers and task heads.

        Args:
            config: model configuration. This method reads ``label_map``,
                ``pretrained_config``, ``hidden_size``, ``num_attention_heads``,
                ``hidden_dropout_prob`` and, if present, ``classifier_dropout``.
            temporal_fusion_strategy: strategy name forwarded to ``TemporalFusion``.
            num_years: number of rows in the year-embedding table.
        """
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy
        self.model = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years

        # Explicit None check: the former ``... or hidden_dropout_prob`` form
        # also discarded a deliberate ``classifier_dropout = 0.0``.
        classifier_dropout = getattr(config, "classifier_dropout", 0.1)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        self.temporal_fusion = TemporalFusion(
            config.hidden_size, strategy=self.temporal_fusion_strategy, num_years=num_years
        )
        # Additional transformer layers on top of the backbone. The layer is
        # sequence-first (batch_first is not set), hence the transposes in forward().
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )
        # One linear head per task, sized by that task's label count.
        self.token_classifiers = nn.ModuleDict({
            task: nn.Linear(config.hidden_size, num_labels)
            for task, num_labels in self.num_token_labels_dict.items()
        })
        self.post_init()

    def _backbone_inputs(self, input_ids, inputs_embeds, year_index):
        """Return the ``(input_ids, inputs_embeds)`` pair to pass to the backbone.

        At most one of the two returned values is non-None when ``input_ids``
        is given, so the backbone never receives both.
        """
        if self.temporal_fusion_strategy == "early-cross-attention":
            if inputs_embeds is None:
                # NOTE(review): ``self.model.embeddings`` is the backbone's whole
                # embedding module; feeding its output back in as
                # ``inputs_embeds`` may apply position/type embeddings twice for
                # BERT-style backbones. Kept as-is because released weights may
                # depend on it — confirm before changing.
                inputs_embeds = self.model.embeddings(input_ids)
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)  # (B, H)
            return None, self.temporal_fusion.cross_attn(inputs_embeds, year_emb)
        if input_ids is not None:
            return input_ids, None
        # Only caller-supplied embeddings are available: forward them instead of
        # dropping them (which left the backbone with no input at all).
        return None, inputs_embeds

    def _omit_token_type_and_head_mask(self):
        """Return True when ``token_type_ids``/``head_mask`` must not be forwarded.

        Besides the wrapper's own ``name_or_path`` (typically the checkpoint
        id), the backbone's name and ``model_type`` are checked, since those
        identify the actual llama/deberta architecture.
        """
        names = (
            self.config.name_or_path,
            getattr(self.model.config, "_name_or_path", None),
            getattr(self.model.config, "model_type", None),
        )
        return any(
            keyword in str(name or "").lower()
            for name in names
            for keyword in ("llama", "deberta")
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Run backbone, extra encoder layers, temporal fusion and task heads.

        Args:
            input_ids / inputs_embeds: token inputs. ``input_ids`` takes
                precedence; ``inputs_embeds`` is used when ids are absent or
                for the early-cross-attention strategy.
            token_labels: optional ``{task: labels}``; a cross-entropy loss is
                computed for every task present and the losses are summed.
            year_index: temporal index consumed by every fusion strategy except
                ``"baseline"``.
            labels, date_indices, decade_index, century_index: accepted for
                signature compatibility; not read by this method.

        Returns:
            ``TokenClassifierOutput`` whose ``logits`` is ``{task: logits}``,
            or, when ``return_dict`` is false, the tuple
            ``([loss,] task_logits, *backbone hidden states/attentions)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        backbone_ids, backbone_embeds = self._backbone_inputs(input_ids, inputs_embeds, year_index)
        bert_kwargs = {
            "inputs_embeds": backbone_embeds,
            "input_ids": backbone_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            # Always ask for a ModelOutput so ``.hidden_states``/``.attentions``
            # can be read below even when the caller requested tuples.
            "return_dict": True,
        }
        if self._omit_token_type_and_head_mask():
            bert_kwargs.pop("token_type_ids", None)
            bert_kwargs.pop("head_mask", None)
        outputs = self.model(**bert_kwargs)

        token_output = self.dropout(outputs[0])  # (B, T, H)
        hidden_states = None
        if output_hidden_states and outputs.hidden_states is not None:
            hidden_states = list(outputs.hidden_states)

        # Pass through additional transformer layers.
        # NOTE(review): no src_key_padding_mask is given, so padded positions take
        # part in self-attention here; left unchanged because the released weights
        # were presumably trained this way — confirm before adding a mask.
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(0, 1)

        # Late fusion for every strategy except baseline / early cross-attention.
        if self.temporal_fusion_strategy not in ("baseline", "early-cross-attention"):
            token_output = self.temporal_fusion(token_output, year_index)
            if hidden_states is not None:
                hidden_states.append(token_output)  # expose the final fused state

        loss_fct = CrossEntropyLoss()
        task_logits = {}
        total_loss = 0
        has_loss = False
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                total_loss += loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                has_loss = True

        if not return_dict:
            # For BERT-style backbones this presumably matches the former
            # ``outputs[2:]`` of the tuple output: hidden states and attentions,
            # when present.
            extras = tuple(
                value for value in (outputs.hidden_states, outputs.attentions) if value is not None
            )
            output = (task_logits,) + extras
            return ((total_loss,) + output) if has_loss else output

        return TokenClassifierOutput(
            # Stays 0 (not None) when no token_labels are given, as before.
            loss=total_loss,
            logits=task_logits,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=outputs.attentions if output_attentions else None,
        )
class TemporalFusion(nn.Module):
    """Condition token representations on a learned year embedding.

    ``strategy`` selects how the year embedding is combined with token states:

    * ``"baseline"`` returns the tokens unchanged.
    * ``"add"`` / ``"multiscale"`` add the embedding to every token.
    * ``"concat"`` concatenates it to every token and projects back to H.
    * ``"film"`` applies a feature-wise affine map from the embedding.
    * ``"relative"`` does the same after an extra encoder.
    * ``"adapter"`` adds an MLP of the embedding.
    * ``"late-cross-attention"`` lets tokens attend to the embedding.

    For ``"early-cross-attention"`` only the sub-modules are built here; calling
    ``forward`` with that strategy (or an unknown one) raises ``ValueError``.
    Year indices are offsets from ``min_year``; the table has ``num_years`` rows.
    """

    def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
        super().__init__()
        self.strategy = strategy
        self.hidden_size = hidden_size
        self.min_year = min_year
        # Inclusive upper bound of the year range covered by ``year_emb``.
        self.max_year = min_year + num_years - 1
        self.year_emb = nn.Embedding(num_years, hidden_size)

        # Sub-modules are created in a fixed order per strategy so parameter
        # names (and random initialisation order) stay stable.
        h = hidden_size
        if strategy == "concat":
            self.concat_proj = nn.Linear(2 * h, h)
        elif strategy in ("film", "relative"):
            if strategy == "relative":
                self.relative_encoder = nn.Sequential(
                    nn.Linear(h, h),
                    nn.SiLU(),
                    nn.LayerNorm(h),
                )
            self.film_gamma = nn.Linear(h, h)
            self.film_beta = nn.Linear(h, h)
        elif strategy == "adapter":
            self.adapter = nn.Sequential(
                nn.Linear(h, h),
                nn.ReLU(),
                nn.Linear(h, h),
            )
        elif strategy == "multiscale":
            self.decade_emb = nn.Embedding(1000, h)
            self.century_emb = nn.Embedding(100, h)
        elif strategy in ("early-cross-attention", "late-cross-attention"):
            self.year_encoder = nn.Sequential(nn.Linear(h, h), nn.SiLU())
            self.cross_attn = TemporalCrossAttention(h)

    def compute_time_embedding(self, year_index):
        """Return the temporal embedding for ``year_index``.

        ``"multiscale"`` sums year, decade and century embeddings (the absolute
        year is ``year_index + min_year``); the cross-attention strategies pass
        the year embedding through ``year_encoder``; every other strategy uses
        the raw year embedding.
        """
        if self.strategy == "multiscale":
            idx = year_index.long()
            absolute_year = idx + self.min_year
            year_part = self.year_emb(idx)
            decade_part = self.decade_emb((absolute_year // 10).long())
            century_part = self.century_emb((absolute_year // 100).long())
            return year_part + decade_part + century_part
        base = self.year_emb(year_index)
        if self.strategy in ("early-cross-attention", "late-cross-attention"):
            return self.year_encoder(base)
        return base

    def forward(self, token_output, year_index):
        """Fuse the year embedding into ``token_output`` (expected 3-D, B x T x H)."""
        _, seq_len, _ = token_output.size()
        if self.strategy == "baseline":
            return token_output

        year_vec = self.compute_time_embedding(year_index)
        strategy = self.strategy

        if strategy == "add":
            return token_output + year_vec.unsqueeze(1).repeat(1, seq_len, 1)
        if strategy == "concat":
            per_token_year = year_vec.unsqueeze(1).repeat(1, seq_len, 1)
            return self.concat_proj(torch.cat([token_output, per_token_year], dim=-1))
        if strategy in ("film", "relative"):
            cond = self.relative_encoder(year_vec) if strategy == "relative" else year_vec
            scale = self.film_gamma(cond).unsqueeze(1)
            shift = self.film_beta(cond).unsqueeze(1)
            return scale * token_output + shift
        if strategy == "adapter":
            return token_output + self.adapter(year_vec).unsqueeze(1)
        if strategy == "multiscale":
            return token_output + year_vec.unsqueeze(1).expand(-1, seq_len, -1)
        if strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_vec)
        raise ValueError(f"Unknown fusion strategy: {self.strategy}")
class TemporalCrossAttention(nn.Module):
    """Residual multi-head attention from token states onto a single time vector."""

    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, sequence, hidden).
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, token_output, time_embedding):
        """Attend from every token to the time embedding and add the result.

        ``token_output`` is (B, T, H) and ``time_embedding`` is (B, H); the time
        vector is turned into a length-1 key/value sequence.
        """
        memory = time_embedding.unsqueeze(1)  # (B, 1, H)
        attended = self.attn(query=token_output, key=memory, value=memory)[0]
        return token_output + attended