|
|
|
|
| import numpy as np
|
| import os
|
| import ntpath
|
| import time
|
| import glob
|
| from scipy.misc import imresize
|
| import torchvision.utils as vutils
|
| from operator import itemgetter
|
| from tensorboardX import SummaryWriter
|
|
|
|
|
class Visualizer:
    """Thin wrapper around a tensorboardX ``SummaryWriter``.

    Logs images, word/ingredient lists, scalars and parameter histograms into
    ``checkpoints_dir``. Tags are built as ``'<mode>/<name>'`` so that the
    entries logged with different ``mode`` values appear as separate groups.
    """

    def __init__(self, checkpoints_dir, name):
        """Create the writer after deleting existing ``events*`` files.

        Args:
            checkpoints_dir: directory the event files (and the scalar JSON
                export of ``scalar_summary``) are written to.
            name: stored as ``self.name``; not used by the methods below.
        """
        self.win_size = 256  # not used inside this class; kept for callers
        self.name = name
        # Cleared by reset(); not read inside this class
        # (presumably consulted by callers — confirm).
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4  # not used inside this class; kept for callers

        # Delete existing tensorboard event files in the directory before a
        # new writer is created there (presumably so a previous run's curves
        # are not mixed in). glob.escape keeps a directory name containing
        # glob metacharacters such as '[' or '*' from being treated as a
        # pattern.
        pattern = os.path.join(glob.escape(checkpoints_dir), 'events*')
        for filename in glob.glob(pattern):
            os.remove(filename)

        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log ``images`` as a single grid under tag ``'<mode>/Image'``.

        ``vutils.make_grid`` tiles the batch with ``normalize=True`` and
        ``scale_each=True``, i.e. every image is min-max scaled on its own.
        ``images`` is presumably a (B, C, H, W) tensor — confirm with callers.
        """
        grid = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), grid, epoch)

    @staticmethod
    def _indices_to_words(indices, vocabulary):
        """Look up ``indices`` in ``vocabulary`` and always return a list.

        Tensors/arrays are converted with ``tolist()`` so lookups use plain
        Python ints. A scalar (e.g. the 0-d tensor produced by squeezing a
        single index) is wrapped in a one-element list. Unlike
        ``operator.itemgetter(*indices)``, this never returns a bare word for
        one index and never fails for zero indices.
        """
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()
        try:
            indices = list(indices)
        except TypeError:  # a single scalar index
            indices = [indices]
        return [vocabulary[i] for i in indices]

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log every entry of ``text`` as a comma-separated list of words.

        Each entry is written under ``'<mode>/<type>_<i>_gt'`` (``gt=True``)
        or ``'<mode>/<type>_<i>_prediction'`` (``gt=False``) at step ``epoch``.
        '<pad>' words are dropped from the logged string.

        Args:
            mode: tag prefix.
            epoch: global step of the summary.
            type: part of the tag name. It shadows the builtin ``type``, but
                the name is kept so keyword callers keep working.
            text: iterable of per-sample entries.
                - ``gt=True``: each entry is a sequence of vocabulary indices.
                - ``gt=False``: each entry is presumably a multi-hot tensor
                  whose non-zero positions mark predicted words (TODO confirm).
            vocabulary: indexable object mapping index -> word.
            gt: selects the decoding described above and the tag suffix.
            max_length: if more words than this are looked up (pads
                included), a "too big" message is logged instead.
        """
        suffix = 'gt' if gt else 'prediction'
        for i, el in enumerate(text):
            if gt:
                idx = el
            else:
                # Non-zero positions, shifted by one before lookup (presumably
                # the prediction vector omits vocabulary entry 0 — confirm).
                idx = el.nonzero().squeeze() + 1

            words_list = self._indices_to_words(idx, vocabulary)
            tag = '{}/{}_{}_{}'.format(mode, type, i, suffix)

            if len(words_list) <= max_length:
                self.writer.add_text(
                    tag, ', '.join(w for w in words_list if w != '<pad>'), epoch)
            else:
                self.writer.add_text(
                    tag,
                    'Number of sampled ingredients is too big: {}'.format(len(words_list)),
                    epoch)

    def scalar_summary(self, mode, epoch, **kwargs):
        """Log each keyword argument as scalar ``'<mode>/<name>'`` at ``epoch``.

        Afterwards the writer's scalars are exported to
        ``<checkpoints_dir>/tensorboard_all_scalars.json``. This export runs
        on every call.
        """
        for key, value in kwargs.items():
            self.writer.add_scalar('{}/{}'.format(mode, key), value, epoch)

        self.writer.export_scalars_to_json(
            os.path.join(self.checkpoints_dir, 'tensorboard_all_scalars.json'))

    def histo_summary(self, model, step):
        """Log one histogram per named parameter of ``model`` at ``step``.

        Each histogram is tagged with the parameter's name as returned by
        ``model.named_parameters()``.
        """
        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
|
|
|