|
|
|
|
| import nltk
|
| import pickle
|
| import argparse
|
| from collections import Counter
|
| import json
|
| import os
|
| from tqdm import *
|
| import numpy as np
|
| import re
|
|
|
|
|
| class Vocabulary(object):
|
| """Simple vocabulary wrapper."""
|
| def __init__(self):
|
| self.word2idx = {}
|
| self.idx2word = {}
|
| self.idx = 0
|
|
|
| def add_word(self, word, idx=None):
|
| if idx is None:
|
| if not word in self.word2idx:
|
| self.word2idx[word] = self.idx
|
| self.idx2word[self.idx] = word
|
| self.idx += 1
|
| return self.idx
|
| else:
|
| if not word in self.word2idx:
|
| self.word2idx[word] = idx
|
| if idx in self.idx2word.keys():
|
| self.idx2word[idx].append(word)
|
| else:
|
| self.idx2word[idx] = [word]
|
|
|
| return idx
|
|
|
| def __call__(self, word):
|
| if not word in self.word2idx:
|
| return self.word2idx['<pad>']
|
| return self.word2idx[word]
|
|
|
| def __len__(self):
|
| return len(self.idx2word)
|
|
|
|
|
| def get_ingredient(det_ingr, replace_dict):
|
| det_ingr_undrs = det_ingr['text'].lower()
|
| det_ingr_undrs = ''.join(i for i in det_ingr_undrs if not i.isdigit())
|
|
|
| for rep, char_list in replace_dict.items():
|
| for c_ in char_list:
|
| if c_ in det_ingr_undrs:
|
| det_ingr_undrs = det_ingr_undrs.replace(c_, rep)
|
| det_ingr_undrs = det_ingr_undrs.strip()
|
| det_ingr_undrs = det_ingr_undrs.replace(' ', '_')
|
|
|
| return det_ingr_undrs
|
|
|
|
|
| def get_instruction(instruction, replace_dict, instruction_mode=True):
|
| instruction = instruction.lower()
|
|
|
| for rep, char_list in replace_dict.items():
|
| for c_ in char_list:
|
| if c_ in instruction:
|
| instruction = instruction.replace(c_, rep)
|
| instruction = instruction.strip()
|
|
|
| if len(instruction) > 0 and instruction[0].isdigit() and instruction_mode:
|
| instruction = ''
|
| return instruction
|
|
|
|
|
| def remove_plurals(counter_ingrs, ingr_clusters):
|
| del_ingrs = []
|
|
|
| for k, v in counter_ingrs.items():
|
|
|
| if len(k) == 0:
|
| del_ingrs.append(k)
|
| continue
|
|
|
| gotit = 0
|
| if k[-2:] == 'es':
|
| if k[:-2] in counter_ingrs.keys():
|
| counter_ingrs[k[:-2]] += v
|
| ingr_clusters[k[:-2]].extend(ingr_clusters[k])
|
| del_ingrs.append(k)
|
| gotit = 1
|
|
|
| if k[-1] == 's' and gotit == 0:
|
| if k[:-1] in counter_ingrs.keys():
|
| counter_ingrs[k[:-1]] += v
|
| ingr_clusters[k[:-1]].extend(ingr_clusters[k])
|
| del_ingrs.append(k)
|
| for item in del_ingrs:
|
| del counter_ingrs[item]
|
| del ingr_clusters[item]
|
| return counter_ingrs, ingr_clusters
|
|
|
|
|
| def cluster_ingredients(counter_ingrs):
|
| mydict = dict()
|
| mydict_ingrs = dict()
|
|
|
| for k, v in counter_ingrs.items():
|
|
|
| w1 = k.split('_')[-1]
|
| w2 = k.split('_')[0]
|
| lw = [w1, w2]
|
| if len(k.split('_')) > 1:
|
| w3 = k.split('_')[0] + '_' + k.split('_')[1]
|
| w4 = k.split('_')[-2] + '_' + k.split('_')[-1]
|
|
|
| lw = [w1, w2, w4, w3]
|
|
|
| gotit = 0
|
| for w in lw:
|
| if w in counter_ingrs.keys():
|
|
|
| parts = w.split('_')
|
| if len(parts) > 0:
|
| if parts[0] in counter_ingrs.keys():
|
| w = parts[0]
|
| elif parts[1] in counter_ingrs.keys():
|
| w = parts[1]
|
| if w in mydict.keys():
|
| mydict[w] += v
|
| mydict_ingrs[w].append(k)
|
| else:
|
| mydict[w] = v
|
| mydict_ingrs[w] = [k]
|
| gotit = 1
|
| break
|
| if gotit == 0:
|
| mydict[k] = v
|
| mydict_ingrs[k] = [k]
|
|
|
| return mydict, mydict_ingrs
|
|
|
|
|
| def update_counter(list_, counter_toks, istrain=False):
|
| for sentence in list_:
|
| tokens = nltk.tokenize.word_tokenize(sentence)
|
| if istrain:
|
| counter_toks.update(tokens)
|
|
|
|
|
| def build_vocab_recipe1m(args):
|
| print ("Loading data...")
|
| dets = json.load(open(os.path.join(args.recipe1m_path, 'det_ingrs.json'), 'r'))
|
| layer1 = json.load(open(os.path.join(args.recipe1m_path, 'layer1.json'), 'r'))
|
| layer2 = json.load(open(os.path.join(args.recipe1m_path, 'layer2.json'), 'r'))
|
|
|
| id2im = {}
|
|
|
| for i, entry in enumerate(layer2):
|
| id2im[entry['id']] = i
|
|
|
| print("Loaded data.")
|
| print("Found %d recipes in the dataset." % (len(layer1)))
|
| replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
|
| replace_dict_instrs = {'and': ['&', "'n"], '': ['#', '[', ']']}
|
|
|
| idx2ind = {}
|
| for i, entry in enumerate(dets):
|
| idx2ind[entry['id']] = i
|
|
|
| ingrs_file = args.save_path + 'allingrs_count.pkl'
|
| instrs_file = args.save_path + 'allwords_count.pkl'
|
|
|
|
|
|
|
|
|
| if os.path.exists(ingrs_file) and os.path.exists(instrs_file) and not args.forcegen:
|
| print ("loading pre-extracted word counters")
|
| counter_ingrs = pickle.load(open(args.save_path + 'allingrs_count.pkl', 'rb'))
|
| counter_toks = pickle.load(open(args.save_path + 'allwords_count.pkl', 'rb'))
|
| else:
|
| counter_toks = Counter()
|
| counter_ingrs = Counter()
|
| counter_ingrs_raw = Counter()
|
|
|
| for i, entry in tqdm(enumerate(layer1)):
|
|
|
|
|
| instrs = entry['instructions']
|
|
|
| instrs_list = []
|
| ingrs_list = []
|
|
|
|
|
| det_ingrs = dets[idx2ind[entry['id']]]['ingredients']
|
|
|
| valid = dets[idx2ind[entry['id']]]['valid']
|
| det_ingrs_filtered = []
|
|
|
| for j, det_ingr in enumerate(det_ingrs):
|
| if len(det_ingr) > 0 and valid[j]:
|
| det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
|
| det_ingrs_filtered.append(det_ingr_undrs)
|
| ingrs_list.append(det_ingr_undrs)
|
|
|
|
|
| acc_len = 0
|
| for instr in instrs:
|
| instr = instr['text']
|
| instr = get_instruction(instr, replace_dict_instrs)
|
| if len(instr) > 0:
|
| instrs_list.append(instr)
|
| acc_len += len(instr)
|
|
|
|
|
| if len(ingrs_list) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
|
| or len(instrs_list) >= args.maxnuminstrs or len(ingrs_list) >= args.maxnumingrs \
|
| or acc_len < args.minnumwords:
|
| continue
|
|
|
|
|
| update_counter(instrs_list, counter_toks, istrain=entry['partition'] == 'train')
|
| title = nltk.tokenize.word_tokenize(entry['title'].lower())
|
| if entry['partition'] == 'train':
|
| counter_toks.update(title)
|
| if entry['partition'] == 'train':
|
| counter_ingrs.update(ingrs_list)
|
|
|
| pickle.dump(counter_ingrs, open(args.save_path + 'allingrs_count.pkl', 'wb'))
|
| pickle.dump(counter_toks, open(args.save_path + 'allwords_count.pkl', 'wb'))
|
| pickle.dump(counter_ingrs_raw, open(args.save_path + 'allingrs_raw_count.pkl', 'wb'))
|
|
|
|
|
| base_words = ['peppers', 'tomato', 'spinach_leaves', 'turkey_breast', 'lettuce_leaf',
|
| 'chicken_thighs', 'milk_powder', 'bread_crumbs', 'onion_flakes',
|
| 'red_pepper', 'pepper_flakes', 'juice_concentrate', 'cracker_crumbs', 'hot_chili',
|
| 'seasoning_mix', 'dill_weed', 'pepper_sauce', 'sprouts', 'cooking_spray', 'cheese_blend',
|
| 'basil_leaves', 'pineapple_chunks', 'marshmallow', 'chile_powder',
|
| 'cheese_blend', 'corn_kernels', 'tomato_sauce', 'chickens', 'cracker_crust',
|
| 'lemonade_concentrate', 'red_chili', 'mushroom_caps', 'mushroom_cap', 'breaded_chicken',
|
| 'frozen_pineapple', 'pineapple_chunks', 'seasoning_mix', 'seaweed', 'onion_flakes',
|
| 'bouillon_granules', 'lettuce_leaf', 'stuffing_mix', 'parsley_flakes', 'chicken_breast',
|
| 'basil_leaves', 'baguettes', 'green_tea', 'peanut_butter', 'green_onion', 'fresh_cilantro',
|
| 'breaded_chicken', 'hot_pepper', 'dried_lavender', 'white_chocolate',
|
| 'dill_weed', 'cake_mix', 'cheese_spread', 'turkey_breast', 'chucken_thighs', 'basil_leaves',
|
| 'mandarin_orange', 'laurel', 'cabbage_head', 'pistachio', 'cheese_dip',
|
| 'thyme_leave', 'boneless_pork', 'red_pepper', 'onion_dip', 'skinless_chicken', 'dark_chocolate',
|
| 'canned_corn', 'muffin', 'cracker_crust', 'bread_crumbs', 'frozen_broccoli',
|
| 'philadelphia', 'cracker_crust', 'chicken_breast']
|
|
|
| for base_word in base_words:
|
|
|
| if base_word not in counter_ingrs.keys():
|
| counter_ingrs[base_word] = 1
|
|
|
| counter_ingrs, cluster_ingrs = cluster_ingredients(counter_ingrs)
|
| counter_ingrs, cluster_ingrs = remove_plurals(counter_ingrs, cluster_ingrs)
|
|
|
|
|
| words = [word for word, cnt in counter_toks.items() if cnt >= args.threshold_words]
|
| ingrs = {word: cnt for word, cnt in counter_ingrs.items() if cnt >= args.threshold_ingrs}
|
|
|
|
|
|
|
| vocab_toks = Vocabulary()
|
| vocab_toks.add_word('<start>')
|
| vocab_toks.add_word('<end>')
|
| vocab_toks.add_word('<eoi>')
|
|
|
|
|
| for i, word in enumerate(words):
|
| vocab_toks.add_word(word)
|
| vocab_toks.add_word('<pad>')
|
|
|
|
|
|
|
| vocab_ingrs = Vocabulary()
|
| idx = vocab_ingrs.add_word('<end>')
|
|
|
|
|
| for k, _ in ingrs.items():
|
| for ingr in cluster_ingrs[k]:
|
| idx = vocab_ingrs.add_word(ingr, idx)
|
| idx += 1
|
| _ = vocab_ingrs.add_word('<pad>', idx)
|
|
|
| print("Total ingr vocabulary size: {}".format(len(vocab_ingrs)))
|
| print("Total token vocabulary size: {}".format(len(vocab_toks)))
|
|
|
| dataset = {'train': [], 'val': [], 'test': []}
|
|
|
|
|
|
|
|
|
| for i, entry in tqdm(enumerate(layer1)):
|
|
|
|
|
| instrs = entry['instructions']
|
|
|
| instrs_list = []
|
| ingrs_list = []
|
| images_list = []
|
|
|
|
|
| det_ingrs = dets[idx2ind[entry['id']]]['ingredients']
|
| valid = dets[idx2ind[entry['id']]]['valid']
|
| labels = []
|
|
|
| for j, det_ingr in enumerate(det_ingrs):
|
| if len(det_ingr) > 0 and valid[j]:
|
| det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
|
| ingrs_list.append(det_ingr_undrs)
|
| label_idx = vocab_ingrs(det_ingr_undrs)
|
| if label_idx is not vocab_ingrs('<pad>') and label_idx not in labels:
|
| labels.append(label_idx)
|
|
|
|
|
| acc_len = 0
|
| for instr in instrs:
|
| instr = instr['text']
|
| instr = get_instruction(instr, replace_dict_instrs)
|
| if len(instr) > 0:
|
| acc_len += len(instr)
|
| instrs_list.append(instr)
|
|
|
|
|
| if len(labels) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
|
| or len(instrs_list) >= args.maxnuminstrs or len(labels) >= args.maxnumingrs \
|
| or acc_len < args.minnumwords:
|
| continue
|
|
|
| if entry['id'] in id2im.keys():
|
| ims = layer2[id2im[entry['id']]]
|
|
|
|
|
| for im in ims['images']:
|
| images_list.append(im['id'])
|
|
|
|
|
| toks = []
|
|
|
| for instr in instrs_list:
|
| tokens = nltk.tokenize.word_tokenize(instr)
|
| toks.append(tokens)
|
|
|
| title = nltk.tokenize.word_tokenize(entry['title'].lower())
|
|
|
| newentry = {'id': entry['id'], 'instructions': instrs_list, 'tokenized': toks,
|
| 'ingredients': ingrs_list, 'images': images_list, 'title': title}
|
| dataset[entry['partition']].append(newentry)
|
|
|
| print('Dataset size:')
|
| for split in dataset.keys():
|
| print(split, ':', len(dataset[split]))
|
|
|
| return vocab_ingrs, vocab_toks, dataset
|
|
|
|
|
| def main(args):
|
|
|
| vocab_ingrs, vocab_toks, dataset = build_vocab_recipe1m(args)
|
|
|
| with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_ingrs.pkl'), 'wb') as f:
|
| pickle.dump(vocab_ingrs, f)
|
| with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_toks.pkl'), 'wb') as f:
|
| pickle.dump(vocab_toks, f)
|
|
|
| for split in dataset.keys():
|
| with open(os.path.join(args.save_path, args.suff+'recipe1m_' + split + '.pkl'), 'wb') as f:
|
| pickle.dump(dataset[split], f)
|
|
|
|
|
| if __name__ == '__main__':
|
|
|
| parser = argparse.ArgumentParser()
|
| parser.add_argument('--recipe1m_path', type=str,
|
| default='path/to/recipe1m',
|
| help='recipe1m path')
|
|
|
| parser.add_argument('--save_path', type=str, default='../data/',
|
| help='path for saving vocabulary wrapper')
|
|
|
| parser.add_argument('--suff', type=str, default='')
|
|
|
| parser.add_argument('--threshold_ingrs', type=int, default=10,
|
| help='minimum ingr count threshold')
|
|
|
| parser.add_argument('--threshold_words', type=int, default=10,
|
| help='minimum word count threshold')
|
|
|
| parser.add_argument('--maxnuminstrs', type=int, default=20,
|
| help='max number of instructions (sentences)')
|
|
|
| parser.add_argument('--maxnumingrs', type=int, default=20,
|
| help='max number of ingredients')
|
|
|
| parser.add_argument('--minnuminstrs', type=int, default=2,
|
| help='max number of instructions (sentences)')
|
|
|
| parser.add_argument('--minnumingrs', type=int, default=2,
|
| help='max number of ingredients')
|
|
|
| parser.add_argument('--minnumwords', type=int, default=20,
|
| help='minimum number of characters in recipe')
|
|
|
| parser.add_argument('--forcegen', dest='forcegen', action='store_true')
|
| parser.set_defaults(forcegen=False)
|
|
|
| args = parser.parse_args()
|
| main(args)
|
|
|