
Commit 53ce4d6

reformat document, add test and remove test long
1 parent f25e701 commit 53ce4d6

File tree

1 file changed: +79 additions, -62 deletions


src/make_dataset.py

Lines changed: 79 additions & 62 deletions
@@ -1,132 +1,149 @@
-import argparse
-import logging
 import os
 import pdb
 import sys
-import traceback
-import pickle
 import json
-import pandas as pd
+import logging
+import traceback
 import collections
+
+import argparse
+import pickle
 from tqdm import tqdm
-from embedding import Embedding
+
 from preprocessor import Preprocessor
-from torch.utils.data import DataLoader
+
 
 def main(args):
-    config_path = os.path.join(args.dest_dir, 'config.json')
-    logging.info('loading configuration from {}'.format(config_path))
+    config_path = os.path.join(args.dest_dir, "config.json")
+    logging.info("loading configuration from {}".format(config_path))
     with open(config_path) as f:
         config = json.load(f)
-
+
     preprocessor = Preprocessor(None)
-
-    logging.info('loading training data from {}'.format(config['train_path']))
-    with open(config['train_path'], 'r') as f:
+
+    logging.info("loading training data from {}".format(config["train_path"]))
+    with open(config["train_path"], "r") as f:
         train_data = f.readlines()
-
-    logging.info('loading validation data from {}'.format(config['valid_path']))
-    with open(config['valid_path'], 'r') as f:
+
+    logging.info("loading validation data from {}".format(config["valid_path"]))
+    with open(config["valid_path"], "r") as f:
         valid_data = f.readlines()
-
-    logging.info('loading testing data from {}'.format(config['test_path']))
-    with open(config['test_path'], 'r') as f:
+
+    logging.info("loading testing data from {}".format(config["test_path"]))
+    with open(config["test_path"], "r") as f:
         test_data = f.readlines()
-
-    logging.info('loading long corpus testing data from {}'.format(config['long_test_path']))
-    with open(config['long_test_path'], 'r') as f:
+
+    logging.info(
+        "loading long corpus testing data from {}".format(config["long_test_path"])
+    )
+    with open(config["long_test_path"], "r") as f:
         long_test_data = f.readlines()
-
+
     # collect words appear in the data
-    logging.info('collecting words from training set...')
+    logging.info("collecting words from training set...")
     words = collections.Counter()
     for data in tqdm(train_data, total=len(train_data)):
         words.update(data.strip().split())
-    logging.info('{} words collected'.format(len(words)))
-
+    logging.info("{} words collected".format(len(words)))
+
     # build sorted vocab dictionary (for adaptive_softmax loss later)
     word_dict = {}
     counter = 4
-    word_dict['<PAD>'] = 0
-    word_dict['<UNK>'] = 1
-    word_dict['<SOS>'] = 2
-    word_dict['<EOS>'] = 3
+    word_dict["<PAD>"] = 0
+    word_dict["<UNK>"] = 1
+    word_dict["<SOS>"] = 2
+    word_dict["<EOS>"] = 3
     for word in tqdm(words.most_common()):
         if word[1] > args.threshold:
             word_dict[word[0]] = counter
             counter += 1
-
-    logging.info('{} words saved'.format(counter))
-
-    vocab_path = '_{}.pkl'.format(args.threshold).join(config['vocab_path'].split('.pkl'))
-    word_set_path = '_{}.pkl'.format(args.threshold).join(config['word_set_path'].split('.pkl'))
-
-    with open(vocab_path, 'wb') as fout:
+
+    logging.info("{} words saved".format(counter))
+
+    vocab_path = "_{}.pkl".format(args.threshold).join(
+        config["vocab_path"].split(".pkl")
+    )
+    word_set_path = "_{}.pkl".format(args.threshold).join(
+        config["word_set_path"].split(".pkl")
+    )
+
+    with open(vocab_path, "wb") as fout:
         pickle.dump(word_dict, fout)
-    with open(word_set_path, 'wb') as fout:
+    with open(word_set_path, "wb") as fout:
         pickle.dump(words, fout)
-    logging.info('Word frequency and vocab saved in {}, {}'.format(word_set_path, vocab_path))
+    logging.info(
+        "Word frequency and vocab saved in {}, {}".format(word_set_path, vocab_path)
+    )
 
     # update word dictionary used by preprocessor
     preprocessor.words_dict = word_dict
-
+
     # train
-    logging.info('Processing training set from {}'.format(config['train_path']))
+    logging.info("Processing training set from {}".format(config["train_path"]))
     train = preprocessor.get_dataset(train_data, args.n_workers)
-    train_pkl_path = os.path.join(args.dest_dir, 'train_{}.pkl'.format(args.threshold))
-    logging.info('Saving training set to {}'.format(train_pkl_path))
-    with open(train_pkl_path, 'wb') as f:
+    train_pkl_path = os.path.join(
+        args.dest_dir, "train_{}.pkl".format(args.threshold)
+    )
+    logging.info("Saving training set to {}".format(train_pkl_path))
+    with open(train_pkl_path, "wb") as f:
         pickle.dump(train, f)
-
+
     # valid
-    logging.info('Processing validation set from {}'.format(config['valid_path']))
+    logging.info("Processing validation set from {}".format(config["valid_path"]))
    valid = preprocessor.get_dataset(valid_data, args.n_workers)
-    valid_pkl_path = os.path.join(args.dest_dir, 'valid_{}.pkl'.format(args.threshold))
-    logging.info('Saving validation set to {}'.format(valid_pkl_path))
-    with open(valid_pkl_path, 'wb') as f:
+    valid_pkl_path = os.path.join(
+        args.dest_dir, "valid_{}.pkl".format(args.threshold)
+    )
+    logging.info("Saving validation set to {}".format(valid_pkl_path))
+    with open(valid_pkl_path, "wb") as f:
         pickle.dump(valid, f)
-
+
     # test
     logging.info('Processing testing set from {}'.format(config['test_path']))
     test = preprocessor.get_dataset(test_data, args.n_workers)
     test_pkl_path = os.path.join(args.dest_dir, 'test_{}.pkl'.format(args.threshold))
     logging.info('Saving testing set to {}'.format(test_pkl_path))
     with open(test_pkl_path, 'wb') as f:
         pickle.dump(test, f)
-
+
     # long test
     logging.info('Processing long corpus testing set from {}'.format(config['long_test_path']))
     long_test = preprocessor.get_dataset(long_test_data, args.n_workers)
     long_test_pkl_path = os.path.join(args.dest_dir, 'long_test_{}.pkl'.format(args.threshold))
     logging.info('Saving long corpus testing set to {}'.format(long_test_pkl_path))
     with open(long_test_pkl_path, 'wb') as f:
         pickle.dump(long_test, f)
-
-    # TEST
+
+    # TEST
     # dataloader = DataLoader(long_test,
     #                         collate_fn=long_test.collate_fn,
     #                         batch_size=4,
     #                         shuffle=False, num_workers=args.n_workers)
-
+
     # for data in dataloader:
+    #     import pdb
     #     pdb.set_trace()
     #     print(data)
 
+
 def _parse_args():
     parser = argparse.ArgumentParser(
-        description="Preprocess and generate preprocessed pickle.")
-    parser.add_argument('dest_dir', type=str,
-                        help='[input] Path to the directory that .')
-    parser.add_argument('--n_workers', type=int, default=8)
-    parser.add_argument('--threshold', type=int, default=0)
+        description="Preprocess and generate preprocessed pickle."
+    )
+    parser.add_argument(
+        "dest_dir", type=str, help="[input] Path to the directory that ."
+    )
+    parser.add_argument("--n_workers", type=int, default=8)
+    parser.add_argument("--threshold", type=int, default=0)
     args = parser.parse_args()
     return args
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     logging.basicConfig(
-        format='%(asctime)s | %(levelname)s | %(name)s: %(message)s',
-        level=logging.INFO, datefmt='%m-%d %H:%M:%S'
+        format="%(asctime)s | %(levelname)s | %(name)s: %(message)s",
+        level=logging.INFO,
+        datefmt="%m-%d %H:%M:%S",
     )
     args = _parse_args()
     try:
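Note on the artifacts this script writes: the '_{}.pkl'.format(args.threshold).join(...split('.pkl')) idiom inserts the threshold into the configured pickle names (vocab.pkl becomes vocab_5.pkl when --threshold 5 is used). Below is a minimal sketch of reading back the vocabulary and word-frequency pickles; the concrete file names vocab_0.pkl and word_set_0.pkl are hypothetical placeholders for whatever config["vocab_path"] and config["word_set_path"] resolve to, while the pickle contents (a token-to-index dict and a collections.Counter) match what the diff above dumps.

import pickle

# Hypothetical paths -- substitute the suffixed vocab_path / word_set_path
# produced for the threshold you actually ran with.
vocab_path = "vocab_0.pkl"
word_set_path = "word_set_0.pkl"

with open(vocab_path, "rb") as f:
    word_dict = pickle.load(f)   # token -> index; <PAD>/<UNK>/<SOS>/<EOS> occupy indices 0-3

with open(word_set_path, "rb") as f:
    words = pickle.load(f)       # collections.Counter of raw word frequencies

print("{} tokens kept out of {} distinct words seen".format(len(word_dict), len(words)))
print(words.most_common(10))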
