1 | | -import argparse |
2 | | -import logging |
3 | 1 | import os |
4 | 2 | import pdb |
5 | 3 | import sys |
6 | | -import traceback |
7 | | -import pickle |
8 | 4 | import json |
9 | | -import pandas as pd |
| 5 | +import logging |
| 6 | +import traceback |
10 | 7 | import collections |
| 8 | + |
| 9 | +import argparse |
| 10 | +import pickle |
11 | 11 | from tqdm import tqdm |
12 | | -from embedding import Embedding |
| 12 | + |
13 | 13 | from preprocessor import Preprocessor |
14 | | -from torch.utils.data import DataLoader |
| 14 | + |
15 | 15 |
|
16 | 16 | def main(args): |
17 | | - config_path = os.path.join(args.dest_dir, 'config.json') |
18 | | - logging.info('loading configuration from {}'.format(config_path)) |
| 17 | + config_path = os.path.join(args.dest_dir, "config.json") |
| 18 | + logging.info("loading configuration from {}".format(config_path)) |
19 | 19 | with open(config_path) as f: |
20 | 20 | config = json.load(f) |
21 | | - |
| 21 | + |
22 | 22 | preprocessor = Preprocessor(None) |
23 | | - |
24 | | - logging.info('loading training data from {}'.format(config['train_path'])) |
25 | | - with open(config['train_path'], 'r') as f: |
| 23 | + |
| 24 | + logging.info("loading training data from {}".format(config["train_path"])) |
| 25 | + with open(config["train_path"], "r") as f: |
26 | 26 | train_data = f.readlines() |
27 | | - |
28 | | - logging.info('loading validation data from {}'.format(config['valid_path'])) |
29 | | - with open(config['valid_path'], 'r') as f: |
| 27 | + |
| 28 | + logging.info("loading validation data from {}".format(config["valid_path"])) |
| 29 | + with open(config["valid_path"], "r") as f: |
30 | 30 | valid_data = f.readlines() |
31 | | - |
32 | | - logging.info('loading testing data from {}'.format(config['test_path'])) |
33 | | - with open(config['test_path'], 'r') as f: |
| 31 | + |
| 32 | + logging.info("loading testing data from {}".format(config["test_path"])) |
| 33 | + with open(config["test_path"], "r") as f: |
34 | 34 | test_data = f.readlines() |
35 | | - |
36 | | - logging.info('loading long corpus testing data from {}'.format(config['long_test_path'])) |
37 | | - with open(config['long_test_path'], 'r') as f: |
| 35 | + |
| 36 | + logging.info( |
| 37 | + "loading long corpus testing data from {}".format(config["long_test_path"]) |
| 38 | + ) |
| 39 | + with open(config["long_test_path"], "r") as f: |
38 | 40 | long_test_data = f.readlines() |
39 | | - |
| 41 | + |
40 | 42 | # collect words that appear in the data |
41 | | - logging.info('collecting words from training set...') |
| 43 | + logging.info("collecting words from training set...") |
42 | 44 | words = collections.Counter() |
43 | 45 | for data in tqdm(train_data, total=len(train_data)): |
44 | 46 | words.update(data.strip().split()) |
45 | | - logging.info('{} words collected'.format(len(words))) |
46 | | - |
| 47 | + logging.info("{} words collected".format(len(words))) |
| 48 | + |
47 | 49 | # build sorted vocab dictionary (for adaptive_softmax loss later) |
48 | 50 | word_dict = {} |
49 | 51 | counter = 4 |
50 | | - word_dict['<PAD>'] = 0 |
51 | | - word_dict['<UNK>'] = 1 |
52 | | - word_dict['<SOS>'] = 2 |
53 | | - word_dict['<EOS>'] = 3 |
| 52 | + word_dict["<PAD>"] = 0 |
| 53 | + word_dict["<UNK>"] = 1 |
| 54 | + word_dict["<SOS>"] = 2 |
| 55 | + word_dict["<EOS>"] = 3 |
54 | 56 | for word in tqdm(words.most_common()): |
55 | 57 | if word[1] > args.threshold: |
56 | 58 | word_dict[word[0]] = counter |
57 | 59 | counter += 1 |
58 | | - |
59 | | - logging.info('{} words saved'.format(counter)) |
60 | | - |
61 | | - vocab_path = '_{}.pkl'.format(args.threshold).join(config['vocab_path'].split('.pkl')) |
62 | | - word_set_path = '_{}.pkl'.format(args.threshold).join(config['word_set_path'].split('.pkl')) |
63 | | - |
64 | | - with open(vocab_path, 'wb') as fout: |
| 60 | + |
| 61 | + logging.info("{} words saved".format(counter)) |
| 62 | + |
| 63 | + vocab_path = "_{}.pkl".format(args.threshold).join( |
| 64 | + config["vocab_path"].split(".pkl") |
| 65 | + ) |
| 66 | + word_set_path = "_{}.pkl".format(args.threshold).join( |
| 67 | + config["word_set_path"].split(".pkl") |
| 68 | + ) |
| 69 | + |
| 70 | + with open(vocab_path, "wb") as fout: |
65 | 71 | pickle.dump(word_dict, fout) |
66 | | - with open(word_set_path, 'wb') as fout: |
| 72 | + with open(word_set_path, "wb") as fout: |
67 | 73 | pickle.dump(words, fout) |
68 | | - logging.info('Word frequency and vocab saved in {}, {}'.format(word_set_path, vocab_path)) |
| 74 | + logging.info( |
| 75 | + "Word frequency and vocab saved in {}, {}".format(word_set_path, vocab_path) |
| 76 | + ) |
69 | 77 |
|
70 | 78 | # update word dictionary used by preprocessor |
71 | 79 | preprocessor.words_dict = word_dict |
72 | | - |
| 80 | + |
73 | 81 | # train |
74 | | - logging.info('Processing training set from {}'.format(config['train_path'])) |
| 82 | + logging.info("Processing training set from {}".format(config["train_path"])) |
75 | 83 | train = preprocessor.get_dataset(train_data, args.n_workers) |
76 | | - train_pkl_path = os.path.join(args.dest_dir, 'train_{}.pkl'.format(args.threshold)) |
77 | | - logging.info('Saving training set to {}'.format(train_pkl_path)) |
78 | | - with open(train_pkl_path, 'wb') as f: |
| 84 | + train_pkl_path = os.path.join( |
| 85 | + args.dest_dir, "train_{}.pkl".format(args.threshold) |
| 86 | + ) |
| 87 | + logging.info("Saving training set to {}".format(train_pkl_path)) |
| 88 | + with open(train_pkl_path, "wb") as f: |
79 | 89 | pickle.dump(train, f) |
80 | | - |
| 90 | + |
81 | 91 | # valid |
82 | | - logging.info('Processing validation set from {}'.format(config['valid_path'])) |
| 92 | + logging.info("Processing validation set from {}".format(config["valid_path"])) |
83 | 93 | valid = preprocessor.get_dataset(valid_data, args.n_workers) |
84 | | - valid_pkl_path = os.path.join(args.dest_dir, 'valid_{}.pkl'.format(args.threshold)) |
85 | | - logging.info('Saving validation set to {}'.format(valid_pkl_path)) |
86 | | - with open(valid_pkl_path, 'wb') as f: |
| 94 | + valid_pkl_path = os.path.join( |
| 95 | + args.dest_dir, "valid_{}.pkl".format(args.threshold) |
| 96 | + ) |
| 97 | + logging.info("Saving validation set to {}".format(valid_pkl_path)) |
| 98 | + with open(valid_pkl_path, "wb") as f: |
87 | 99 | pickle.dump(valid, f) |
88 | | - |
| 100 | + |
89 | 101 | # test |
90 | 102 | logging.info('Processing testing set from {}'.format(config['test_path'])) |
91 | 103 | test = preprocessor.get_dataset(test_data, args.n_workers) |
92 | 104 | test_pkl_path = os.path.join(args.dest_dir, 'test_{}.pkl'.format(args.threshold)) |
93 | 105 | logging.info('Saving testing set to {}'.format(test_pkl_path)) |
94 | 106 | with open(test_pkl_path, 'wb') as f: |
95 | 107 | pickle.dump(test, f) |
96 | | - |
| 108 | + |
97 | 109 | # long test |
98 | 110 | logging.info('Processing long corpus testing set from {}'.format(config['long_test_path'])) |
99 | 111 | long_test = preprocessor.get_dataset(long_test_data, args.n_workers) |
100 | 112 | long_test_pkl_path = os.path.join(args.dest_dir, 'long_test_{}.pkl'.format(args.threshold)) |
101 | 113 | logging.info('Saving long corpus testing set to {}'.format(long_test_pkl_path)) |
102 | 114 | with open(long_test_pkl_path, 'wb') as f: |
103 | 115 | pickle.dump(long_test, f) |
104 | | - |
105 | | - # TEST |
| 116 | + |
| 117 | +# TEST |
106 | 118 | # dataloader = DataLoader(long_test, |
107 | 119 | # collate_fn=long_test.collate_fn, |
108 | 120 | # batch_size=4, |
109 | 121 | # shuffle=False, num_workers=args.n_workers) |
110 | | - |
| 122 | + |
111 | 123 | # for data in dataloader: |
| 124 | +# import pdb |
112 | 125 | # pdb.set_trace() |
113 | 126 | # print(data) |
114 | 127 |
|
| 128 | + |
115 | 129 | def _parse_args(): |
116 | 130 | parser = argparse.ArgumentParser( |
117 | | - description="Preprocess and generate preprocessed pickle.") |
118 | | - parser.add_argument('dest_dir', type=str, |
119 | | - help='[input] Path to the directory that .') |
120 | | - parser.add_argument('--n_workers', type=int, default=8) |
121 | | - parser.add_argument('--threshold', type=int, default=0) |
| 131 | + description="Preprocess and generate preprocessed pickle." |
| 132 | + ) |
| 133 | + parser.add_argument( |
| 134 | + "dest_dir", type=str, help="[input] Path to the directory that ." |
| 135 | + ) |
| 136 | + parser.add_argument("--n_workers", type=int, default=8) |
| 137 | + parser.add_argument("--threshold", type=int, default=0) |
122 | 138 | args = parser.parse_args() |
123 | 139 | return args |
124 | 140 |
|
125 | 141 |
|
126 | | -if __name__ == '__main__': |
| 142 | +if __name__ == "__main__": |
127 | 143 | logging.basicConfig( |
128 | | - format='%(asctime)s | %(levelname)s | %(name)s: %(message)s', |
129 | | - level=logging.INFO, datefmt='%m-%d %H:%M:%S' |
| 144 | + format="%(asctime)s | %(levelname)s | %(name)s: %(message)s", |
| 145 | + level=logging.INFO, |
| 146 | + datefmt="%m-%d %H:%M:%S", |
130 | 147 | ) |
131 | 148 | args = _parse_args() |
132 | 149 | try: |
|