Skip to content

Commit f076f47

Browse files
committed
reformat document
1 parent ea60313 commit f076f47

File tree

1 file changed: +10 additions, -39 deletions

src/train.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,36 @@
1-
import argparse
2-
import logging
31
import os
4-
import pdb
5-
import pickle
62
import sys
7-
import traceback
83
import json
4+
import logging
5+
import traceback
6+
7+
import pickle
8+
import argparse
9+
10+
from metrics import Accuracy
911
from callbacks import ModelCheckpoint, MetricsLogger
10-
from metrics import *
11-
from lr_finder import LRFinder
1212

1313

1414
def main(args):
1515
config_path = os.path.join(args.model_dir, 'config.json')
1616
with open(config_path) as f:
1717
config = json.load(f)
18-
18+
1919
logging.info('loading word dictionary...')
2020
with open(config['words_dict'], 'rb') as f:
2121
words_dict = pickle.load(f)
2222

2323
logging.info('loading train data...')
2424
with open(config['train'], 'rb') as f:
2525
train = pickle.load(f)
26-
26+
2727
logging.info('loading validation data...')
2828
with open(config['model_parameters']['valid'], 'rb') as f:
2929
valid = pickle.load(f)
3030
config['model_parameters']['valid'] = valid
31-
31+
3232
if args.lr_finder:
33-
# logging.info('creating model!')
34-
35-
# vocab_size = len(words_dict)
36-
# hidden_size = config['model_parameters']['hidden_size']
37-
# embedding_size = config['model_parameters']['embedding_size']
38-
# use_same_embedding = config['model_parameters']['use_same_embedding']
39-
40-
# encoder = EncoderRNN(vocab_size, hidden_size, embedding_size)
41-
# decoder = DecoderRNN(hidden_size, vocab_size, embedding_size, use_same_embedding)
42-
43-
# criterion = nn.CrossEntropyLoss()
44-
45-
# if config['model_parameters']['optimizer'] == 'Adam':
46-
# optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
47-
# lr=1e-7,
48-
# weight_decay=1e-2)
49-
# elif config['model_parameters']['optimizer'] == 'SGD':
50-
# optimizer = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()),
51-
# lr=1e-7,
52-
# momentum=0.9,
53-
# weight_decay=1e-2)
54-
55-
# lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
56-
# lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode="exp")
57-
58-
# # use cuda
59-
# self.encoder = self.encoder.to(self.device)
60-
# self.decoder = self.decoder.to(self.device)
6133
pass
62-
6334
else:
6435
if config['arch'] == 'Predictor':
6536
from predictor import Predictor

0 commit comments

Comments (0)