Commit 3e4b8f2

remove unused get_opt from predict.py; revert unneeded change to train.py
1 parent 7e24ae1 commit 3e4b8f2

2 files changed: +4 -5 lines changed

predict.py

Lines changed: 1 addition & 2 deletions
@@ -33,7 +33,7 @@
 import argparse, os
 from sqlnet.dbengine import DBEngine
 from sqlova.utils.utils_wikisql import *
-from train import construct_hyper_param, get_models, get_opt
+from train import construct_hyper_param, get_models
 
 # This is a stripped down version of the test() method in train.py - identical, except:
 # - does not attempt to measure accuracy and indeed does not expect the data to be labelled.
@@ -107,7 +107,6 @@ def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
 path_model = args.model_file
 args.no_pretraining = True  # counterintuitive, but avoids loading unused models
 model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
-opt, opt_bert = get_opt(model, model_bert, args)
 
 # Load data
 dev_data, dev_table = load_wikisql_data(args.data_path, mode=args.split, toy_model=args.toy_model, toy_size=args.toy_size, no_hs_tok=True)
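
Why dropping the get_opt call is safe: predict.py only runs the trained models forward and never takes a gradient step, so the optimizers it built were never used. A minimal, self-contained sketch of that point (the nn.Linear below is a stand-in for the SQLova/BERT models that get_models(..., trained=True) actually returns; it is not code from this repository):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)              # stand-in for the loaded, already-trained model
model.eval()                         # inference mode
with torch.no_grad():                # no gradients are computed, so no optimizer is ever needed
    scores = model(torch.randn(1, 4))
print(scores.shape)                  # torch.Size([1, 2])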

train.py

Lines changed: 3 additions & 3 deletions
@@ -127,8 +127,8 @@ def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining):
 
     return model_bert, tokenizer, bert_config
 
-def get_opt(model, model_bert, args):
-    if args.fine_tune:
+def get_opt(model, model_bert, fine_tune):
+    if fine_tune:
         opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, weight_decay=0)
 
@@ -582,7 +582,7 @@ def print_result(epoch, acc, dname):
     # model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
 
     ## 5. Get optimizers
-    opt, opt_bert = get_opt(model, model_bert, args)
+    opt, opt_bert = get_opt(model, model_bert, args.fine_tune)
 
     ## 6. Train
     acc_lx_t_best = -1
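
For context, the train.py revert restores get_opt's original signature: it takes a plain fine_tune boolean rather than the whole args namespace, and the call site unpacks args.fine_tune. A runnable sketch of that calling convention (the Namespace and nn.Linear objects are stand-ins, and the else branch and second optimizer are assumptions for illustration, since the hunk above only shows the fine-tune branch's first optimizer):

import argparse
import torch
import torch.nn as nn

# Stand-ins so the sketch runs on its own; in train.py these are the real models and hyperparameters.
args = argparse.Namespace(lr=1e-3, lr_bert=1e-5, fine_tune=True)
model, model_bert = nn.Linear(4, 2), nn.Linear(4, 2)

def get_opt(model, model_bert, fine_tune):
    # Reverted signature: a boolean flag, not the whole args namespace.
    if fine_tune:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()),
                                    lr=args.lr_bert, weight_decay=0)  # assumed; not shown in the hunk
    else:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = None  # assumed: BERT is left untouched when not fine-tuning
    return opt, opt_bert

opt, opt_bert = get_opt(model, model_bert, args.fine_tune)  # call site, as in the second hunk above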
