emanuelaboros commited on
Commit
53bc89f
·
verified ·
1 Parent(s): f6e0a3e

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +9 -0
  2. newsagency_ner.py +241 -0
config.json CHANGED
@@ -5,6 +5,15 @@
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
 
 
 
 
 
 
 
 
 
8
  "gradient_checkpointing": false,
9
  "hidden_act": "gelu",
10
  "hidden_dropout_prob": 0.1,
 
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
8
+ "custom_pipelines": {
9
+ "newsagency-ner": {
10
+ "impl": "newsagency_ner.NewsAgencyModelPipeline",
11
+ "pt": [
12
+ "ModelForSequenceAndTokenClassification"
13
+ ],
14
+ "tf": []
15
+ }
16
+ },
17
  "gradient_checkpointing": false,
18
  "hidden_act": "gelu",
19
  "hidden_dropout_prob": 0.1,
newsagency_ner.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoTokenizer,
3
+ Pipeline,
4
+ )
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ from nltk.chunk import conlltags2tree
9
+ from nltk import pos_tag
10
+ from nltk.tree import Tree
11
+ import string
12
+
13
+ label2id = {
14
+ "B-org.ent.pressagency.Reuters": 0,
15
+ "B-org.ent.pressagency.Stefani": 1,
16
+ "O": 2,
17
+ "B-org.ent.pressagency.Extel": 3,
18
+ "B-org.ent.pressagency.Havas": 4,
19
+ "I-org.ent.pressagency.Xinhua": 5,
20
+ "I-org.ent.pressagency.Domei": 6,
21
+ "B-org.ent.pressagency.Belga": 7,
22
+ "B-org.ent.pressagency.CTK": 8,
23
+ "B-org.ent.pressagency.ANSA": 9,
24
+ "B-org.ent.pressagency.DNB": 10,
25
+ "B-org.ent.pressagency.Domei": 11,
26
+ "I-pers.ind.articleauthor": 12,
27
+ "I-org.ent.pressagency.Wolff": 13,
28
+ "B-org.ent.pressagency.unk": 14,
29
+ "I-org.ent.pressagency.Stefani": 15,
30
+ "I-org.ent.pressagency.AFP": 16,
31
+ "B-org.ent.pressagency.UP-UPI": 17,
32
+ "I-org.ent.pressagency.ATS-SDA": 18,
33
+ "I-org.ent.pressagency.unk": 19,
34
+ "B-org.ent.pressagency.DPA": 20,
35
+ "B-org.ent.pressagency.AFP": 21,
36
+ "I-org.ent.pressagency.DNB": 22,
37
+ "B-pers.ind.articleauthor": 23,
38
+ "I-org.ent.pressagency.UP-UPI": 24,
39
+ "B-org.ent.pressagency.Kipa": 25,
40
+ "B-org.ent.pressagency.Wolff": 26,
41
+ "B-org.ent.pressagency.ag": 27,
42
+ "I-org.ent.pressagency.Extel": 28,
43
+ "I-org.ent.pressagency.ag": 29,
44
+ "B-org.ent.pressagency.ATS-SDA": 30,
45
+ "I-org.ent.pressagency.Havas": 31,
46
+ "I-org.ent.pressagency.Reuters": 32,
47
+ "B-org.ent.pressagency.Xinhua": 33,
48
+ "B-org.ent.pressagency.AP": 34,
49
+ "B-org.ent.pressagency.APA": 35,
50
+ "I-org.ent.pressagency.ANSA": 36,
51
+ "B-org.ent.pressagency.DDP-DAPD": 37,
52
+ "I-org.ent.pressagency.TASS": 38,
53
+ "I-org.ent.pressagency.AP": 39,
54
+ "B-org.ent.pressagency.TASS": 40,
55
+ "B-org.ent.pressagency.Europapress": 41,
56
+ "B-org.ent.pressagency.SPK-SMP": 42,
57
+ }
58
+
59
+ id2label = {v: k for k, v in label2id.items()}
60
+
61
+
62
+ def tokenize(text):
63
+ # print(text)
64
+ for punctuation in string.punctuation:
65
+ text = text.replace(punctuation, " " + punctuation + " ")
66
+ return text.split()
67
+
68
+
69
+ def get_entities(tokens, tags):
70
+ tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
71
+ pos_tags = [pos for token, pos in pos_tag(tokens)]
72
+
73
+ conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
74
+ ne_tree = conlltags2tree(conlltags)
75
+
76
+ entities = []
77
+ idx = 0
78
+ char_position = 0 # This will hold the current character position
79
+
80
+ for subtree in ne_tree:
81
+ # skipping 'O' tags
82
+ if isinstance(subtree, Tree):
83
+ original_label = subtree.label()
84
+ original_string = " ".join([token for token, pos in subtree.leaves()])
85
+
86
+ entity_start_position = char_position
87
+ entity_end_position = entity_start_position + len(original_string)
88
+
89
+ entities.append(
90
+ (
91
+ original_string,
92
+ original_label,
93
+ (idx, idx + len(subtree)),
94
+ (entity_start_position, entity_end_position),
95
+ )
96
+ )
97
+ idx += len(subtree)
98
+
99
+ # Update the current character position
100
+ # We add the length of the original string + 1 (for the space)
101
+ char_position += len(original_string) + 1
102
+ else:
103
+ token, pos = subtree
104
+ # If it's not a named entity, we still need to update the character
105
+ # position
106
+ char_position += len(token) + 1 # We add 1 for the space
107
+ idx += 1
108
+
109
+ return entities
110
+
111
+
112
+ def realign(text_sentence, out_label_preds, tokenizer, reverted_label_map):
113
+ preds_list, words_list, confidence_list = [], [], []
114
+ word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
115
+ for idx, word in enumerate(text_sentence):
116
+
117
+ try:
118
+ beginning_index = word_ids.index(idx)
119
+ preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
120
+ except Exception as ex: # the sentence was longer then max_length
121
+ preds_list.append("O")
122
+ words_list.append(word)
123
+ return words_list, preds_list
124
+
125
+
126
+ class NewsAgencyModelPipeline(Pipeline):
127
+ # def __init__(self, model_id, config, **kwargs):
128
+ # super().__init__(model_id, config, **kwargs)
129
+ # self.tokenizer = AutoTokenizer.from_pretrained(model_id)
130
+ #
131
+ # self.model = ModelForSequenceAndTokenClassification.from_pretrained(
132
+ # model_id,
133
+ # num_sequence_labels=2,
134
+ # num_token_labels=len(label2id),
135
+ # )
136
+ # self.model.eval() # Set the model to evaluation mode
137
+ # def __init__(self, model, tokenizer, **kwargs):
138
+ # super().__init__(self, model, tokenizer, **kwargs)
139
+ # self.model = model
140
+ # self.tokenizer = tokenizer
141
+
142
+ def _sanitize_parameters(self, **kwargs):
143
+ # Add any additional parameter handling if necessary
144
+ return kwargs, {}, {}
145
+
146
+ def preprocess(self, text, **kwargs):
147
+ tokenized_inputs = self.tokenizer(
148
+ text,
149
+ padding="max_length",
150
+ truncation=True,
151
+ max_length=128,
152
+ # We use this argument because the texts in our dataset are lists
153
+ # of words (with a label for each word).
154
+ # is_split_into_words=True
155
+ )
156
+
157
+ text_sentence = tokenize(text)
158
+ return tokenized_inputs, text_sentence
159
+
160
+ def _forward(self, inputs):
161
+ inputs, text_sentence = inputs
162
+ input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
163
+ self.model.device
164
+ )
165
+ with torch.no_grad():
166
+ outputs = self.model(input_ids)
167
+ return outputs, text_sentence
168
+
169
+ def postprocess(self, outputs, **kwargs):
170
+ # postprocess the outputs here, for example, convert predictions to labels
171
+ # outputs = ... # some processing here
172
+
173
+ outputs, text_sentence = outputs
174
+ try:
175
+ _, tokens_result = outputs[0], outputs[1]
176
+ except:
177
+ tokens_result = outputs[0]
178
+
179
+ tokens_result = np.argmax(
180
+ tokens_result["logits"].detach().cpu().numpy(), axis=2
181
+ )[0]
182
+
183
+ words_list, preds_list = realign(
184
+ text_sentence,
185
+ tokens_result,
186
+ self.tokenizer,
187
+ id2label,
188
+ )
189
+
190
+ entities = get_entities(words_list, preds_list)
191
+ # print('*'*20, 'Result:', entities)
192
+
193
+ return [entities]
194
+
195
+ # def postprocess(self, outputs, **kwargs):
196
+ #
197
+ # # Extract and process logits
198
+ # outputs, inputs = outputs[0], outputs[1]
199
+ #
200
+ # token_logits, sequence_logits = outputs[0], outputs[1]
201
+ #
202
+ # token_logits = token_logits.logits.detach().cpu().numpy()
203
+ # sequence_logits = sequence_logits.logits.detach().cpu().numpy()
204
+ #
205
+ # text_sentences = [
206
+ # self.tokenizer.convert_ids_to_tokens(input_ids)
207
+ # for input_ids in inputs["input_ids"].detach().cpu().numpy()
208
+ # ]
209
+ #
210
+ # sequence_preds = np.argmax(token_logits, axis=-1)
211
+ # token_preds = np.argmax(sequence_logits, axis=1)
212
+ #
213
+ # # sequence_preds = torch.argmax(sequence_logits, dim=-1)
214
+ # # token_preds = torch.argmax(token_logits, dim=-1)
215
+ #
216
+ # preds_list = [[] for _ in range(token_preds.shape[0])]
217
+ # words_list = [[] for _ in range(token_preds.shape[0])]
218
+ #
219
+ # for idx_sentence, item in enumerate(zip(text_sentences, token_preds)):
220
+ # text_sentence, out_label_preds = item
221
+ # word_ids = self.tokenizer(
222
+ # text_sentence, is_split_into_words=True
223
+ # ).word_ids()
224
+ # for idx, word in enumerate(text_sentence):
225
+ # beginning_index = word_ids.index(idx)
226
+ #
227
+ # try:
228
+ # preds_list[idx_sentence].append(
229
+ # id2label[out_label_preds[beginning_index]]
230
+ # )
231
+ # except BaseException: # the sentence was longer then max_length
232
+ # preds_list[idx_sentence].append("O")
233
+ # words_list[idx_sentence].append(word)
234
+ #
235
+ # import pdb
236
+ #
237
+ # pdb.set_trace()
238
+ # return {
239
+ # "sequence_classification": sequence_preds.cpu().numpy(),
240
+ # "token_classification": token_preds.cpu().numpy(),
241
+ # }