mirbostani committed on
Commit
f3afc88
1 Parent(s): 2b91529

Upload run_triviaqa.py

Browse files
Files changed (1) hide show
  1. run_triviaqa.py +888 -0
run_triviaqa.py ADDED
@@ -0,0 +1,888 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Finetuning the library models for question-answering on TriviaQA converted to the SQuAD format (DistilBERT, Bert, XLM, XLNet)."""
17
+
18
+
19
+ import argparse
20
+ import glob
21
+ import logging
22
+ import os
23
+ import random
24
+ import timeit
25
+ import json
26
+
27
+ import numpy as np
28
+ import torch
29
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
30
+ from torch.utils.data.distributed import DistributedSampler
31
+ from tqdm import tqdm, trange
32
+
33
+ import transformers
34
+ from transformers import (
35
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
36
+ WEIGHTS_NAME,
37
+ AdamW,
38
+ AutoConfig,
39
+ AutoModelForQuestionAnswering,
40
+ AutoTokenizer,
41
+ get_linear_schedule_with_warmup,
42
+ squad_convert_examples_to_features,
43
+ )
44
+ from transformers.data.metrics.squad_metrics import (
45
+ compute_predictions_log_probs,
46
+ compute_predictions_logits,
47
+ squad_evaluate,
48
+ )
49
+ from transformers.data.processors.squad import SquadExample, SquadResult, SquadProcessor, SquadV1Processor, SquadV2Processor
50
+ from transformers.data.processors.utils import DataProcessor
51
+ from transformers.trainer_utils import is_main_process
52
+
53
+
54
+ try:
55
+ from torch.utils.tensorboard import SummaryWriter
56
+ except ImportError:
57
+ from tensorboardX import SummaryWriter
58
+
59
+
60
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

# Config classes registered for question answering, and their ``model_type``
# strings (listed in the ``--model_type`` help text).
MODEL_CONFIG_CLASSES = [*MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()]
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)
64
+
65
class TriviaQAProcessor(SquadProcessor):
    """
    Processor for the TriviaQA dataset stored in the SQuAD JSON layout.

    https://github.com/mandarjoshi90/triviaqa

    @see transformers/src/transformers/data/processors/squad.py
    """

    train_file = "squad-triviaqa-wikipedia-train.json"  # wikipedia or web
    dev_file = "squad-triviaqa-wikipedia-dev.json"  # wikipedia or web

    def _create_examples(self, input_data, set_type):
        """Turn SQuAD-style TriviaQA entries into a list of ``SquadExample``.

        Args:
            input_data: Sequence of entries; each entry has a "paragraphs" list
                whose items carry a "context" string and a "qas" list.
            set_type: ``"train"`` uses the first answer of each question as the
                gold span; any other value keeps the whole answer list.

        Returns:
            The examples built from every question that has at least one
            answer. Questions without answers are skipped and only counted in
            the summary printed at the end.
        """
        is_training = set_type == "train"
        examples = []
        meta = {
            "has_answer": 0,
            "has_no_answer": 0,
        }
        # Wrap the data itself (not the enumerate object) so tqdm can read its
        # length and display a real progress percentage.
        for entry_id, entry in enumerate(tqdm(input_data)):
            # TriviaQA entries do not have entry["title"]
            title = str(entry_id)
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    # Some example fields in TriviaQA are empty (e.g. "answers": []);
                    # a missing or null "answers" field is treated the same way.
                    qa_answers = qa.get("answers") or []
                    if not qa_answers:
                        # ignore questions with no answer
                        meta["has_no_answer"] += 1
                        continue

                    start_position_character = None
                    answer_text = None
                    answers = []

                    is_impossible = qa.get("is_impossible", False)
                    if not is_impossible:
                        if is_training:
                            answer = qa_answers[0]
                            answer_text = answer["text"]
                            start_position_character = answer["answer_start"]
                        else:
                            answers = qa_answers

                    examples.append(
                        SquadExample(
                            qas_id=qa["id"],
                            question_text=qa["question"],
                            context_text=context_text,
                            answer_text=answer_text,
                            start_position_character=start_position_character,
                            title=title,
                            is_impossible=is_impossible,
                            answers=answers,
                        )
                    )
                    meta["has_answer"] += 1

        print(json.dumps(meta, indent=4))
        return examples
128
+
129
+
130
+
131
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    The CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed_value = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)
137
+
138
+
139
def to_list(tensor):
    """Return the tensor's values as a (nested) Python list.

    The tensor is detached from the autograd graph and copied to the CPU first.
    """
    detached = tensor.detach()
    host_copy = detached.cpu()
    return host_copy.tolist()
141
+
142
+
143
def train(args, train_dataset, model, tokenizer):
    """Train the model.

    Args:
        args: Parsed command-line namespace. ``args.train_batch_size`` is written
            here, and ``args.num_train_epochs`` is overwritten when
            ``args.max_steps`` is positive.
        train_dataset: Dataset whose batches hold, in order, input ids, attention
            mask, token type ids, start positions and end positions; XLNet/XLM
            additionally read ``cls_index``, ``p_mask`` and (with
            ``--version_2_with_negative``) ``is_impossible`` from positions 5-7.
        model: Question-answering model to optimise.
        tokenizer: Tokenizer saved next to every checkpoint and forwarded to
            ``evaluate`` when evaluating during training.

    Returns:
        Tuple ``(global_step, tr_loss / global_step)``.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimizer updates per epoch. Clamped to 1 wherever it is used as a divisor,
    # so a dataloader shorter than the accumulation window cannot raise
    # ZeroDivisionError.
    updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", global_step)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

            # The value above counts optimizer updates, but the loop below skips
            # dataloader batches: convert so the right amount of data is skipped
            # when gradients are accumulated over several batches.
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        except ValueError:
            # The path does not end in "-<step>": not a checkpoint directory.
            logger.info(" Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() is the supported way to read the current rate;
                    # fall back to get_lr() on PyTorch versions that predate it.
                    if hasattr(scheduler, "get_last_lr"):
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = scheduler.get_lr()[0]
                    tb_writer.add_scalar("lr", current_lr, global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
335
+
336
+
337
def evaluate(args, model, tokenizer, prefix=""):
    """Run the model over the dev set and score its predictions.

    Args:
        args: Parsed command-line namespace; ``args.eval_batch_size`` is written here.
        model: Question-answering model (wrapped in ``DataParallel`` here when
            several GPUs are available and it is not wrapped already).
        tokenizer: Tokenizer used to build the features and to post-process predictions.
        prefix: Suffix inserted into the names of the prediction files written
            to ``args.output_dir``.

    Returns:
        The value returned by ``squad_evaluate(examples, predictions)``.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    # Switching to eval mode once is enough; nothing in the loop changes the mode.
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                del inputs["token_type_ids"]

            feature_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
                # for lang_id-sensitive xlm models
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )
            outputs = model(**inputs)

        for i, feature_index in enumerate(feature_indices):
            eval_feature = features[feature_index.item()]
            unique_id = int(eval_feature.unique_id)

            # Row i of every output tensor belongs to this feature.
            output = [to_list(output_tensor[i]) for output_tensor in outputs.to_tuple()]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )

            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    # max(..., 1) keeps the per-example figure defined for an empty dataset.
    logger.info(
        " Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / max(len(dataset), 1)
    )

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        # A DataParallel wrapper exposes the config only through ``model.module``.
        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            start_n_top,
            end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results
469
+
470
+
471
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load the TriviaQA features, from the on-disk cache when one exists.

    Args:
        args: Parsed command-line namespace (data locations, sequence limits,
            cache and distributed settings).
        tokenizer: Tokenizer used to convert examples into features.
        evaluate: Load the dev split (``args.predict_file``) instead of the
            training split (``args.train_file``).
        output_examples: Also return the examples and features.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        NotImplementedError: If neither ``args.data_dir`` nor the file for the
            requested split is given.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(review): torch.load unpickles the file, so the cache must come
        # from a trusted location.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        split_file = args.predict_file if evaluate else args.train_file
        if not args.data_dir and not split_file:
            raise NotImplementedError(
                "No input data: provide --data_dir or the {} option; loading TriviaQA from "
                "tensorflow_datasets is not supported.".format("--predict_file" if evaluate else "--train_file")
            )

        processor = TriviaQAProcessor()
        if evaluate:
            examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
        else:
            examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
530
+
531
+
532
+ def main():
533
+ parser = argparse.ArgumentParser()
534
+
535
+ # Required parameters
536
+ parser.add_argument(
537
+ "--model_type",
538
+ default=None,
539
+ type=str,
540
+ required=True,
541
+ help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
542
+ )
543
+ parser.add_argument(
544
+ "--model_name_or_path",
545
+ default=None,
546
+ type=str,
547
+ required=True,
548
+ help="Path to pretrained model or model identifier from huggingface.co/models",
549
+ )
550
+ parser.add_argument(
551
+ "--output_dir",
552
+ default=None,
553
+ type=str,
554
+ required=True,
555
+ help="The output directory where the model checkpoints and predictions will be written.",
556
+ )
557
+
558
+ # Other parameters
559
+ parser.add_argument(
560
+ "--data_dir",
561
+ default=None,
562
+ type=str,
563
+ help="The input data dir. Should contain the .json files for the task."
564
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
565
+ )
566
+ parser.add_argument(
567
+ "--train_file",
568
+ default=None,
569
+ type=str,
570
+ help="The input training file. If a data dir is specified, will look for the file there"
571
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
572
+ )
573
+ parser.add_argument(
574
+ "--predict_file",
575
+ default=None,
576
+ type=str,
577
+ help="The input evaluation file. If a data dir is specified, will look for the file there"
578
+ + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
579
+ )
580
+ parser.add_argument(
581
+ "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
582
+ )
583
+ parser.add_argument(
584
+ "--tokenizer_name",
585
+ default="",
586
+ type=str,
587
+ help="Pretrained tokenizer name or path if not the same as model_name",
588
+ )
589
+ parser.add_argument(
590
+ "--cache_dir",
591
+ default="",
592
+ type=str,
593
+ help="Where do you want to store the pre-trained models downloaded from huggingface.co",
594
+ )
595
+
596
+ parser.add_argument(
597
+ "--version_2_with_negative",
598
+ action="store_true",
599
+ help="If true, the SQuAD examples contain some that do not have an answer.",
600
+ )
601
+ parser.add_argument(
602
+ "--null_score_diff_threshold",
603
+ type=float,
604
+ default=0.0,
605
+ help="If null_score - best_non_null is greater than the threshold predict null.",
606
+ )
607
+
608
+ parser.add_argument(
609
+ "--max_seq_length",
610
+ default=384,
611
+ type=int,
612
+ help="The maximum total input sequence length after WordPiece tokenization. Sequences "
613
+ "longer than this will be truncated, and sequences shorter than this will be padded.",
614
+ )
615
+ parser.add_argument(
616
+ "--doc_stride",
617
+ default=128,
618
+ type=int,
619
+ help="When splitting up a long document into chunks, how much stride to take between chunks.",
620
+ )
621
+ parser.add_argument(
622
+ "--max_query_length",
623
+ default=64,
624
+ type=int,
625
+ help="The maximum number of tokens for the question. Questions longer than this will "
626
+ "be truncated to this length.",
627
+ )
628
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
629
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
630
+ parser.add_argument(
631
+ "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
632
+ )
633
+ parser.add_argument(
634
+ "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
635
+ )
636
+
637
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
638
+ parser.add_argument(
639
+ "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
640
+ )
641
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
642
+ parser.add_argument(
643
+ "--gradient_accumulation_steps",
644
+ type=int,
645
+ default=1,
646
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
647
+ )
648
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
649
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
650
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
651
+ parser.add_argument(
652
+ "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
653
+ )
654
+ parser.add_argument(
655
+ "--max_steps",
656
+ default=-1,
657
+ type=int,
658
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
659
+ )
660
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
661
+ parser.add_argument(
662
+ "--n_best_size",
663
+ default=20,
664
+ type=int,
665
+ help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
666
+ )
667
+ parser.add_argument(
668
+ "--max_answer_length",
669
+ default=30,
670
+ type=int,
671
+ help="The maximum length of an answer that can be generated. This is needed because the start "
672
+ "and end predictions are not conditioned on one another.",
673
+ )
674
+ parser.add_argument(
675
+ "--verbose_logging",
676
+ action="store_true",
677
+ help="If true, all of the warnings related to data processing will be printed. "
678
+ "A number of warnings are expected for a normal SQuAD evaluation.",
679
+ )
680
+ parser.add_argument(
681
+ "--lang_id",
682
+ default=0,
683
+ type=int,
684
+ help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
685
+ )
686
+
687
+ parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
688
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
689
+ parser.add_argument(
690
+ "--eval_all_checkpoints",
691
+ action="store_true",
692
+ help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
693
+ )
694
+ parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
695
+ parser.add_argument(
696
+ "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
697
+ )
698
+ parser.add_argument(
699
+ "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
700
+ )
701
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
702
+
703
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
704
+ parser.add_argument(
705
+ "--fp16",
706
+ action="store_true",
707
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
708
+ )
709
+ parser.add_argument(
710
+ "--fp16_opt_level",
711
+ type=str,
712
+ default="O1",
713
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
714
+ "See details at https://nvidia.github.io/apex/amp.html",
715
+ )
716
+ parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
717
+ parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
718
+
719
+ parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
720
+ args = parser.parse_args()
721
+
722
+ if args.doc_stride >= args.max_seq_length - args.max_query_length:
723
+ logger.warning(
724
+ "WARNING - You've set a doc stride which may be superior to the document length in some "
725
+ "examples. This could result in errors when building features from the examples. Please reduce the doc "
726
+ "stride or increase the maximum length to ensure the features are correctly built."
727
+ )
728
+
729
+ if (
730
+ os.path.exists(args.output_dir)
731
+ and os.listdir(args.output_dir)
732
+ and args.do_train
733
+ and not args.overwrite_output_dir
734
+ ):
735
+ raise ValueError(
736
+ "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
737
+ args.output_dir
738
+ )
739
+ )
740
+
741
+ # Setup distant debugging if needed
742
+ if args.server_ip and args.server_port:
743
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
744
+ import ptvsd
745
+
746
+ print("Waiting for debugger attach")
747
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
748
+ ptvsd.wait_for_attach()
749
+
750
+ # Setup CUDA, GPU & distributed training
751
+ if args.local_rank == -1 or args.no_cuda:
752
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
753
+ args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
754
+ else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
755
+ torch.cuda.set_device(args.local_rank)
756
+ device = torch.device("cuda", args.local_rank)
757
+ torch.distributed.init_process_group(backend="nccl")
758
+ args.n_gpu = 1
759
+ args.device = device
760
+
761
+ # Setup logging
762
+ logging.basicConfig(
763
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
764
+ datefmt="%m/%d/%Y %H:%M:%S",
765
+ level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
766
+ )
767
+ logger.warning(
768
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
769
+ args.local_rank,
770
+ device,
771
+ args.n_gpu,
772
+ bool(args.local_rank != -1),
773
+ args.fp16,
774
+ )
775
+ # Set the verbosity to info of the Transformers logger (on main process only):
776
+ if is_main_process(args.local_rank):
777
+ transformers.utils.logging.set_verbosity_info()
778
+ transformers.utils.logging.enable_default_handler()
779
+ transformers.utils.logging.enable_explicit_format()
780
+ # Set seed
781
+ set_seed(args)
782
+
783
+ # Load pretrained model and tokenizer
784
+ if args.local_rank not in [-1, 0]:
785
+ # Make sure only the first process in distributed training will download model & vocab
786
+ torch.distributed.barrier()
787
+
788
+ args.model_type = args.model_type.lower()
789
+ config = AutoConfig.from_pretrained(
790
+ args.config_name if args.config_name else args.model_name_or_path,
791
+ cache_dir=args.cache_dir if args.cache_dir else None,
792
+ )
793
+ tokenizer = AutoTokenizer.from_pretrained(
794
+ args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
795
+ do_lower_case=args.do_lower_case,
796
+ cache_dir=args.cache_dir if args.cache_dir else None,
797
+ use_fast=False, # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
798
+ )
799
+ model = AutoModelForQuestionAnswering.from_pretrained(
800
+ args.model_name_or_path,
801
+ from_tf=bool(".ckpt" in args.model_name_or_path),
802
+ config=config,
803
+ cache_dir=args.cache_dir if args.cache_dir else None,
804
+ )
805
+
806
+ if args.local_rank == 0:
807
+ # Make sure only the first process in distributed training will download model & vocab
808
+ torch.distributed.barrier()
809
+
810
+ model.to(args.device)
811
+
812
+ logger.info("Training/evaluation parameters %s", args)
813
+
814
+ # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
815
+ # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
816
+ # remove the need for this code, but it is still valid.
817
+ if args.fp16:
818
+ try:
819
+ import apex
820
+
821
+ apex.amp.register_half_function(torch, "einsum")
822
+ except ImportError:
823
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
824
+
825
+ # Training
826
+ if args.do_train:
827
+ train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
828
+ global_step, tr_loss = train(args, train_dataset, model, tokenizer)
829
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
830
+
831
+ # Save the trained model and the tokenizer
832
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
833
+ logger.info("Saving model checkpoint to %s", args.output_dir)
834
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
835
+ # They can then be reloaded using `from_pretrained()`
836
+ # Take care of distributed/parallel training
837
+ model_to_save = model.module if hasattr(model, "module") else model
838
+ model_to_save.save_pretrained(args.output_dir)
839
+ tokenizer.save_pretrained(args.output_dir)
840
+
841
+ # Good practice: save your training arguments together with the trained model
842
+ torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
843
+
844
+ # Load a trained model and vocabulary that you have fine-tuned
845
+ model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True)
846
+
847
+ # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
848
+ # So we use use_fast=False here for now until Fast-tokenizer-compatible-examples are out
849
+ tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, use_fast=False)
850
+ model.to(args.device)
851
+
852
+ # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
853
+ results = {}
854
+ if args.do_eval and args.local_rank in [-1, 0]:
855
+ if args.do_train:
856
+ logger.info("Loading checkpoints saved during training for evaluation")
857
+ checkpoints = [args.output_dir]
858
+ if args.eval_all_checkpoints:
859
+ checkpoints = list(
860
+ os.path.dirname(c)
861
+ for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
862
+ )
863
+
864
+ else:
865
+ logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
866
+ checkpoints = [args.model_name_or_path]
867
+
868
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
869
+
870
+ for checkpoint in checkpoints:
871
+ # Reload the model
872
+ global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
873
+ model = AutoModelForQuestionAnswering.from_pretrained(checkpoint) # , force_download=True)
874
+ model.to(args.device)
875
+
876
+ # Evaluate
877
+ result = evaluate(args, model, tokenizer, prefix=global_step)
878
+
879
+ result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
880
+ results.update(result)
881
+
882
+ logger.info("Results: {}".format(results))
883
+
884
+ return results
885
+
886
+
887
+ if __name__ == "__main__":  # Script entry point: call main() only when this file is executed directly, not when it is imported.
888
+ main()