ner-newsagency-bert-fr / newsagency_ner.py
emanuelaboros's picture
Upload folder using huggingface_hub
53bc89f verified
Raw
History Blame
8.39 kB
from transformers import (
AutoTokenizer,
Pipeline,
)
import numpy as np
import torch
from torch import nn
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree
import string
# Tag name -> output id for the token-classification head. Tags follow the
# IOB scheme consumed by get_entities(): "B-<type>" opens an entity,
# "I-<type>" continues it and "O" is outside any entity. Entity types are
# press agencies ("org.ent.pressagency.<name>") or article authors
# ("pers.ind.articleauthor").
# NOTE(review): ids must stay aligned with the label order the checkpoint was
# trained with (presumably the model config's label2id) -- confirm before
# editing or reordering.
label2id = {
    "B-org.ent.pressagency.Reuters": 0,
    "B-org.ent.pressagency.Stefani": 1,
    "O": 2,
    "B-org.ent.pressagency.Extel": 3,
    "B-org.ent.pressagency.Havas": 4,
    "I-org.ent.pressagency.Xinhua": 5,
    "I-org.ent.pressagency.Domei": 6,
    "B-org.ent.pressagency.Belga": 7,
    "B-org.ent.pressagency.CTK": 8,
    "B-org.ent.pressagency.ANSA": 9,
    "B-org.ent.pressagency.DNB": 10,
    "B-org.ent.pressagency.Domei": 11,
    "I-pers.ind.articleauthor": 12,
    "I-org.ent.pressagency.Wolff": 13,
    "B-org.ent.pressagency.unk": 14,
    "I-org.ent.pressagency.Stefani": 15,
    "I-org.ent.pressagency.AFP": 16,
    "B-org.ent.pressagency.UP-UPI": 17,
    "I-org.ent.pressagency.ATS-SDA": 18,
    "I-org.ent.pressagency.unk": 19,
    "B-org.ent.pressagency.DPA": 20,
    "B-org.ent.pressagency.AFP": 21,
    "I-org.ent.pressagency.DNB": 22,
    "B-pers.ind.articleauthor": 23,
    "I-org.ent.pressagency.UP-UPI": 24,
    "B-org.ent.pressagency.Kipa": 25,
    "B-org.ent.pressagency.Wolff": 26,
    "B-org.ent.pressagency.ag": 27,
    "I-org.ent.pressagency.Extel": 28,
    "I-org.ent.pressagency.ag": 29,
    "B-org.ent.pressagency.ATS-SDA": 30,
    "I-org.ent.pressagency.Havas": 31,
    "I-org.ent.pressagency.Reuters": 32,
    "B-org.ent.pressagency.Xinhua": 33,
    "B-org.ent.pressagency.AP": 34,
    "B-org.ent.pressagency.APA": 35,
    "I-org.ent.pressagency.ANSA": 36,
    "B-org.ent.pressagency.DDP-DAPD": 37,
    "I-org.ent.pressagency.TASS": 38,
    "I-org.ent.pressagency.AP": 39,
    "B-org.ent.pressagency.TASS": 40,
    "B-org.ent.pressagency.Europapress": 41,
    "B-org.ent.pressagency.SPK-SMP": 42,
}
# Inverse mapping (id -> tag name); used to decode the argmax predictions.
id2label = {v: k for k, v in label2id.items()}
# Translation table that surrounds every ASCII punctuation character with
# spaces, so that str.split() turns each mark into a stand-alone token.
_PUNCTUATION_SPACING = str.maketrans(
    {mark: " " + mark + " " for mark in string.punctuation}
)


def tokenize(text):
    """Split ``text`` into words, making each ASCII punctuation mark a token.

    >>> tokenize("Paris, 3 mai (Havas).")
    ['Paris', ',', '3', 'mai', '(', 'Havas', ')', '.']
    """
    spaced = text.translate(_PUNCTUATION_SPACING)
    return spaced.split()
def get_entities(tokens, tags):
    """Group per-word IOB tags into entity mentions.

    Tags carrying the IOBES ``S-``/``E-`` prefixes are first rewritten to
    ``B-``/``I-`` so that NLTK's ``conlltags2tree`` can chunk them.

    Args:
        tokens: The words of the sentence (e.g. the output of ``tokenize``).
        tags: One tag per word: ``"O"`` or ``"<prefix>-<label>"``.

    Returns:
        A list of ``(text, label, (start_word, end_word),
        (start_char, end_char))`` tuples, one per entity. ``text`` is the
        entity's words joined by single spaces, word indices are
        end-exclusive, and character offsets index into
        ``" ".join(tokens)`` (not into the original raw text).
    """
    iob_tags = []
    for tag in tags:
        # Only the prefix may be rewritten: a plain str.replace("S-", "B-")
        # would also corrupt labels that merely contain "S-", turning
        # "B-org.ent.pressagency.ATS-SDA" into "...ATB-SDA".
        if tag.startswith("S-"):
            tag = "B-" + tag[2:]
        elif tag.startswith("E-"):
            tag = "I-" + tag[2:]
        iob_tags.append(tag)

    # conlltags2tree needs (word, pos, chunk_tag) triples, but the POS value
    # is only stored in the tree leaves and never read below, so a placeholder
    # avoids running NLTK's POS tagger (and requiring its model data).
    conlltags = [(token, "X", tag) for token, tag in zip(tokens, iob_tags)]

    entities = []
    word_idx = 0  # index of the current word in ``tokens``
    char_pos = 0  # offset of the current word in " ".join(tokens)
    for node in conlltags2tree(conlltags):
        if isinstance(node, Tree):
            # An entity chunk; its leaves are (word, pos) pairs.
            words = [word for word, _ in node.leaves()]
            text = " ".join(words)
            entities.append(
                (
                    text,
                    node.label(),
                    (word_idx, word_idx + len(words)),
                    (char_pos, char_pos + len(text)),
                )
            )
            word_idx += len(words)
            char_pos += len(text) + 1  # +1 for the separating space
        else:
            # A word tagged "O": a bare (word, pos) pair.
            word, _ = node
            word_idx += 1
            char_pos += len(word) + 1  # +1 for the separating space
    return entities
def realign(text_sentence, out_label_preds, tokenizer, reverted_label_map):
    """Map sub-token predictions back to one tag per word.

    Each word takes the prediction of its first sub-token. A word is tagged
    ``"O"`` when its first sub-token lies beyond ``out_label_preds`` (the
    model input was truncated) or when it produced no sub-token at all.

    NOTE(review): this assumes the model input was tokenized into the same
    sub-token sequence as ``tokenizer(text_sentence, is_split_into_words=True)``;
    the pipeline encodes the raw string instead, so confirm this holds for
    the tokenizer in use.

    Args:
        text_sentence: The list of words the predictions are aligned to.
        out_label_preds: Predicted label ids, one per sub-token position.
        tokenizer: Tokenizer used to recover the word -> sub-token mapping
            through ``word_ids()``.
        reverted_label_map: Mapping from label id to tag string.

    Returns:
        ``(words_list, preds_list)``: the words and their tags, same length.

    Raises:
        KeyError: If a predicted id is missing from ``reverted_label_map``
            (previously swallowed and silently tagged ``"O"``).
    """
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()

    # Position of the first sub-token of every word, built in one pass so
    # the per-word lookup below is O(1) rather than a list.index() scan.
    first_subtoken = {}
    for position, word_id in enumerate(word_ids):
        if word_id is not None:
            first_subtoken.setdefault(word_id, position)

    words_list, preds_list = [], []
    for idx, word in enumerate(text_sentence):
        position = first_subtoken.get(idx)
        if position is None or position >= len(out_label_preds):
            # No sub-token, or the word was cut off by max_length truncation.
            preds_list.append("O")
        else:
            preds_list.append(reverted_label_map[out_label_preds[position]])
        words_list.append(word)
    return words_list, preds_list
class NewsAgencyModelPipeline(Pipeline):
    """Pipeline that tags press-agency and article-author mentions in text.

    ``preprocess`` encodes one input string and splits it into words,
    ``_forward`` runs the model on the encoding, and ``postprocess`` turns
    the per-sub-token logits into word-level tags and then entity tuples.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route every call-time keyword argument to ``preprocess``.

        Returns the ``(preprocess, forward, postprocess)`` keyword dicts;
        ``_forward`` and ``postprocess`` receive none.
        """
        return kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Encode ``text`` for the model and split it into words.

        Args:
            text: The raw input string.
            **kwargs: Accepted for ``Pipeline`` compatibility; unused.

        Returns:
            ``(tokenized_inputs, text_sentence)``: the tokenizer encoding,
            truncated to 128 sub-tokens, and the word list from ``tokenize``
            that the predictions are realigned to.
        """
        tokenized_inputs = self.tokenizer(
            text,
            # Deliberately no padding="max_length": ``_forward`` feeds only
            # ``input_ids`` (no attention mask), and with standard Hugging
            # Face encoders a missing mask means every position, padding
            # included, is attended to, which perturbs the predictions for
            # the real tokens. A single unpadded sequence needs no mask.
            # NOTE(review): assumes the model accepts variable-length input.
            truncation=True,
            max_length=128,
        )
        text_sentence = tokenize(text)
        return tokenized_inputs, text_sentence

    def _forward(self, inputs):
        """Run the model on one encoded text without tracking gradients.

        Returns ``(outputs, text_sentence)`` so that ``postprocess`` still has
        the word list produced by ``preprocess``.
        """
        encoding, text_sentence = inputs
        # Add a batch dimension of 1 and move the ids to the model's device.
        input_ids = torch.tensor([encoding["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids)
        return outputs, text_sentence

    def postprocess(self, outputs, **kwargs):
        """Convert the model's token logits into entity tuples.

        Returns:
            A one-element list holding the entities that ``get_entities``
            extracts from the input text.
        """
        outputs, text_sentence = outputs
        # When the model returns at least two outputs, the second is used as
        # the token-classification result; otherwise fall back to the first.
        try:
            tokens_result = outputs[1]
        except (IndexError, KeyError):
            tokens_result = outputs[0]
        # logits are indexed as (batch, position, label): take the best label
        # per position for the single sequence in the batch.
        token_preds = np.argmax(
            tokens_result["logits"].detach().cpu().numpy(), axis=2
        )[0]
        words_list, preds_list = realign(
            text_sentence,
            token_preds,
            self.tokenizer,
            id2label,
        )
        return [get_entities(words_list, preds_list)]