|
1 | | -import argparse |
2 | | -import logging |
3 | 1 | import os |
4 | | -import pdb |
5 | | -import pickle |
6 | 2 | import sys |
7 | | -import traceback |
8 | 3 | import json |
| 4 | +import logging |
| 5 | +import traceback |
| 6 | + |
| 7 | +import pickle |
| 8 | +import argparse |
| 9 | + |
| 10 | +from metrics import Accuracy |
9 | 11 | from callbacks import ModelCheckpoint, MetricsLogger |
10 | | -from metrics import * |
11 | | -from lr_finder import LRFinder |
12 | 12 |
|
13 | 13 |
|
14 | 14 | def main(args): |
15 | 15 | config_path = os.path.join(args.model_dir, 'config.json') |
16 | 16 | with open(config_path) as f: |
17 | 17 | config = json.load(f) |
18 | | - |
| 18 | + |
19 | 19 | logging.info('loading word dictionary...') |
20 | 20 | with open(config['words_dict'], 'rb') as f: |
21 | 21 | words_dict = pickle.load(f) |
22 | 22 |
|
23 | 23 | logging.info('loading train data...') |
24 | 24 | with open(config['train'], 'rb') as f: |
25 | 25 | train = pickle.load(f) |
26 | | - |
| 26 | + |
27 | 27 | logging.info('loading validation data...') |
28 | 28 | with open(config['model_parameters']['valid'], 'rb') as f: |
29 | 29 | valid = pickle.load(f) |
30 | 30 | config['model_parameters']['valid'] = valid |
31 | | - |
| 31 | + |
32 | 32 | if args.lr_finder: |
33 | | -# logging.info('creating model!') |
34 | | - |
35 | | -# vocab_size = len(words_dict) |
36 | | -# hidden_size = config['model_parameters']['hidden_size'] |
37 | | -# embedding_size = config['model_parameters']['embedding_size'] |
38 | | -# use_same_embedding = config['model_parameters']['use_same_embedding'] |
39 | | - |
40 | | -# encoder = EncoderRNN(vocab_size, hidden_size, embedding_size) |
41 | | -# decoder = DecoderRNN(hidden_size, vocab_size, embedding_size, use_same_embedding) |
42 | | - |
43 | | -# criterion = nn.CrossEntropyLoss() |
44 | | - |
45 | | -# if config['model_parameters']['optimizer'] == 'Adam': |
46 | | -# optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), |
47 | | -# lr=1e-7, |
48 | | -# weight_decay=1e-2) |
49 | | -# elif config['model_parameters']['optimizer'] == 'SGD': |
50 | | -# optimizer = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), |
51 | | -# lr=1e-7, |
52 | | -# momentum=0.9, |
53 | | -# weight_decay=1e-2) |
54 | | - |
55 | | -# lr_finder = LRFinder(model, optimizer, criterion, device="cuda") |
56 | | -# lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode="exp") |
57 | | - |
58 | | -# # use cuda |
59 | | -# self.encoder = self.encoder.to(self.device) |
60 | | -# self.decoder = self.decoder.to(self.device) |
61 | 33 | pass |
62 | | - |
63 | 34 | else: |
64 | 35 | if config['arch'] == 'Predictor': |
65 | 36 | from predictor import Predictor |
|
0 commit comments