|
|
|
|
| import pickle
|
| from tqdm import tqdm
|
| import os
|
| import numpy as np
|
| from PIL import Image
|
| import argparse
|
| import lmdb
|
| from torchvision import transforms
|
|
|
|
|
# Ceiling for the LMDB memory map, in bytes (1e12 bytes, i.e. 1 TB).
# Kept as a float literal; it is cast to int at the point of use.
MAX_SIZE = 1e12
|
|
|
|
|
def load_and_resize(root, path, imscale):
    """Load one image from disk and return it resized and center-cropped.

    The file is looked up at ``root/c0/c1/c2/c3/path`` where ``c0`` .. ``c3``
    are the first four characters of ``path``.

    Args:
        root: Directory that contains the nested per-character folders.
        path: Image file name; its first four characters select the folders.
        imscale: Size given to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.

    Returns:
        The RGB PIL image produced by the resize + center-crop pipeline.

    Raises:
        IndexError: If ``path`` is shorter than four characters.
    """
    transform = transforms.Compose([
        transforms.Resize(imscale),
        transforms.CenterCrop(imscale),
    ])

    image_file = os.path.join(root, path[0], path[1], path[2], path[3], path)

    # Image.open is lazy and keeps the underlying file open. The context
    # manager releases the handle as soon as the pixel data has been copied
    # into the converted image, instead of leaving that to garbage collection.
    with Image.open(image_file) as raw:
        rgb = raw.convert('RGB')

    return transform(rgb)
|
|
|
|
|
def main(args):
    """Write the resized images of every split into one LMDB per split.

    For each of the 'train', 'val' and 'test' splits this reads the pickled
    dataset ``<save_dir>/<suff>recipe1m_<split>.pkl``, takes at most five
    images per entry (and never more than ``args.maxnumims``), and stores each
    image that is not already in ``<save_dir>/lmdb_<split>`` as raw uint8
    bytes keyed by its encoded file name. Images that are already stored are
    skipped, so an interrupted run can be resumed.

    A mapping from image name to a running per-split counter is pickled to
    ``<save_dir>/imname2pos.pkl`` once all splits are done.

    Args:
        args: Parsed command-line namespace; reads ``save_dir``, ``suff``,
            ``root``, ``imscale`` and ``maxnumims``.
    """
    imname2pos = {'train': {}, 'val': {}, 'test': {}}

    for split in ['train', 'val', 'test']:
        dataset_file = os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl')
        with open(dataset_file, 'rb') as f:
            dataset = pickle.load(f)

        image_root = os.path.join(args.root, 'images', split)
        env = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + split), map_size=int(MAX_SIZE))
        try:
            # Collect the keys that are already stored. A set gives O(1)
            # membership tests, and values=False avoids copying every stored
            # image out of the database just to learn its key.
            with env.begin() as txn:
                present_entries = set(txn.cursor().iternext(keys=True, values=False))

            j = 0
            for entry in tqdm(dataset):
                for n, p in enumerate(entry['images'][0:5]):
                    if n == args.maxnumims:
                        break

                    key = p.encode()
                    if key not in present_entries:
                        im = load_and_resize(image_root, p, args.imscale)
                        im = np.array(im).astype(np.uint8)
                        with env.begin(write=True) as txn:
                            txn.put(key, im)
                        # Remember the write so a repeated name is not
                        # loaded and stored a second time.
                        present_entries.add(key)

                    # Recorded for every visited image, whether it was just
                    # written or was already present from an earlier run.
                    imname2pos[split][p] = j
                    j += 1
        finally:
            # Release the environment explicitly so the same LMDB can be
            # re-opened later in this process.
            env.close()

    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb') as f:
        pickle.dump(imname2pos, f)
|
|
|
|
|
def test(args):
    """Read one image back from the 'val' LMDB and print its shape.

    Loads ``<save_dir>/imname2pos.pkl``, picks the first image name listed
    under 'val', fetches its bytes from ``<save_dir>/lmdb_val``, reshapes them
    to ``(imscale, imscale, 3)`` and prints the shape of the resulting RGB
    image.

    Args:
        args: Parsed command-line namespace; reads ``save_dir`` and
            ``imscale``.

    Raises:
        ValueError: If no image names are recorded for the 'val' split.
        KeyError: If the chosen image name is not stored in the LMDB.
    """
    with open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb') as f:
        imname2pos = pickle.load(f)

    # Any stored name will do for this check; take the first one.
    path = next(iter(imname2pos['val']), None)
    if path is None:
        raise ValueError("no 'val' image names found in imname2pos.pkl")

    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
                           lock=False, readahead=False, meminit=False)
    try:
        with image_file.begin(write=False) as txn:
            raw = txn.get(path.encode())
    finally:
        image_file.close()

    if raw is None:
        raise KeyError(path)

    # np.fromstring is deprecated for binary input; frombuffer is the
    # documented replacement and reads the same bytes.
    image = np.frombuffer(raw, dtype=np.uint8)
    image = np.reshape(image, (args.imscale, args.imscale, 3))
    # astype makes a writable copy of the read-only frombuffer view.
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    print(np.shape(image))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the LMDBs unless --test_only is given,
    # then always run the read-back check.
    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for every valued option.
    valued_options = (
        ('--root', str, 'path/to/recipe1m',
         'path to the recipe1m dataset'),
        ('--save_dir', str, '../data',
         'path where the lmdbs will be saved'),
        ('--imscale', int, 256,
         'size of images (will be rescaled and center cropped)'),
        ('--maxnumims', int, 5,
         'maximum number of images to allow for each sample'),
        ('--suff', str, '',
         'id of the vocabulary to use'),
    )
    for flag, value_type, default_value, help_text in valued_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # Boolean switch; off unless passed on the command line.
    parser.add_argument('--test_only', dest='test_only', action='store_true', default=False)

    args = parser.parse_args()

    build_lmdbs = not args.test_only
    if build_lmdbs:
        main(args)
    test(args)
|
|
|