
Commit 2c1c9ba

add a prediction script
This adds a `predict.py` script to do inference with the model, predicting SQL from questions without attempting to evaluate accuracy (which would require the questions to be labeled). The `train.py` script is tweaked slightly so that its methods can be used by `predict.py`. The `annotate_ws.py` script is also tweaked slightly so that it can be run on user data. Two utilities, `add_csv.py` and `add_question.py`, are added to give a quick way to set up new tables and questions. To minimize changes to existing code, I add unlabeled questions with a dummy label, as empty as I could make it while still having working code.
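
Putting the pieces together, a minimal end-to-end run looks like this, reusing the abbrev.csv example from the add_question.py comments (the model and BERT file locations are placeholders):

    python add_csv.py playground abbrev.csv
    python add_question.py playground abbrev "what state has ansi digits of 11"
    python annotate_ws.py --din $PWD --dout $PWD --split playground
    python predict.py --bert_type_abb uL \
        --model_file <path to models>/model_best.pt \
        --bert_model_file <path to models>/model_bert_best.pt \
        --bert_path <path to bert files> \
        --data_path $PWD --result_path $PWD --split playground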
1 parent 8961b4d commit 2c1c9ba


5 files changed: +248 −4 lines changed

add_csv.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
#!/usr/bin/env python

# Add a CSV file as a table into <split>.db and <split>.tables.jsonl
# Call as:
#   python add_csv.py <split> <filename.csv>
# For a CSV file called data.csv, the table will be called table_data in the .db
# file, and will be assigned the id 'data'.
# All columns are treated as text - no attempt is made to sniff the type of value
# stored in the column.

import argparse, csv, json, os
from sqlalchemy import Column, create_engine, MetaData, String, Table

def get_table_name(table_id):
    return 'table_{}'.format(table_id)

def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name):
    engine = create_engine('sqlite:///{}'.format(sqlite_file_name))
    with open(csv_file_name) as f:
        metadata = MetaData(bind=engine)
        cf = csv.DictReader(f, delimiter=',')
        simple_name = dict([(name, 'col%d' % i) for i, name in enumerate(cf.fieldnames)])
        table = Table(get_table_name(table_id), metadata,
                      *(Column(simple_name[name], String())
                        for name in cf.fieldnames))
        table.drop(checkfirst=True)
        table.create()
        for row in cf:
            row = dict((simple_name[name], val) for name, val in row.items())
            table.insert().values(**row).execute()
    return engine

def csv_to_json(table_id, csv_file_name, json_file_name):
    with open(csv_file_name) as f:
        cf = csv.DictReader(f, delimiter=',')
        record = {}
        record['header'] = [(name or 'col{}'.format(i)) for i, name in enumerate(cf.fieldnames)]
        record['page_title'] = None
        record['types'] = ['text'] * len(cf.fieldnames)
        record['id'] = table_id
        record['caption'] = None
        record['rows'] = [list(row.values()) for row in cf]
        record['name'] = get_table_name(table_id)
    with open(json_file_name, 'a+') as fout:
        json.dump(record, fout)
        fout.write('\n')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('split')
    parser.add_argument('file', metavar='file.csv')
    args = parser.parse_args()
    table_id = os.path.splitext(os.path.basename(args.file))[0]
    csv_to_sqlite(table_id, args.file, '{}.db'.format(args.split))
    csv_to_json(table_id, args.file, '{}.tables.jsonl'.format(args.split))
    print("Added table with id '{id}' (name '{name}') to {split}.db and {split}.tables.jsonl".format(
        id=table_id, name=get_table_name(table_id), split=args.split))
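
For reference, a sketch of the single JSON line that csv_to_json appends to <split>.tables.jsonl, shown here for a hypothetical two-column abbrev.csv with headers "Name" and "ANSI" (values illustrative):

    {"header": ["Name", "ANSI"], "page_title": null, "types": ["text", "text"], "id": "abbrev", "caption": null, "rows": [["Alabama", "01"], ["Alaska", "02"]], "name": "table_abbrev"}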

add_question.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
#!/usr/bin/env python

# Add a line of json representing a question into <split>.jsonl
# Call as:
#   python add_question.py <split> <table id> <question>
#
# This utility is not intended for use during training. A dummy label is added to the
# question to make it loadable by existing code.
#
# For example, suppose we downloaded this list of US state abbreviations:
#   https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/USstateAbbreviations.csv
# Let's rename it as something short, say "abbrev.csv"
# Now we can add it to a split called, say, "playground":
#   python add_csv.py playground abbrev.csv
# And now we can add a question about it to the same split:
#   python add_question.py playground abbrev "what state has ansi digits of 11"
# The next step would be to annotate the split:
#   python annotate_ws.py --din $PWD --dout $PWD --split playground
# Then we're ready to run prediction on the split with predict.py

import argparse, json

def question_to_json(table_id, question, json_file_name):
    record = {
        'phase': 1,
        'table_id': table_id,
        'question': question,
        'sql': {'sel': 0, 'conds': [], 'agg': 0}
    }
    with open(json_file_name, 'a+') as fout:
        json.dump(record, fout)
        fout.write('\n')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('split')
    parser.add_argument('table_id')
    parser.add_argument('question', type=str, nargs='+')
    args = parser.parse_args()
    json_file_name = '{}.jsonl'.format(args.split)
    question_to_json(args.table_id, " ".join(args.question), json_file_name)
    print("Added question (with dummy label) to {}".format(json_file_name))

annotate_ws.py

Lines changed: 2 additions & 1 deletion
@@ -155,6 +155,7 @@ def is_valid_example(e):
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
     parser.add_argument('--din', default='/Users/wonseok/data/WikiSQL-1.1/data', help='data directory')
     parser.add_argument('--dout', default='/Users/wonseok/data/wikisql_tok', help='output directory')
+    parser.add_argument('--split', default='train,dev,test', help='comma-separated list of splits to process')
     args = parser.parse_args()
 
     answer_toy = not True
@@ -164,7 +165,7 @@ def is_valid_example(e):
     os.makedirs(args.dout)
 
     # for split in ['train', 'dev', 'test']:
-    for split in ['train', 'dev', 'test']:
+    for split in args.split.split(','):
         fsplit = os.path.join(args.din, split) + '.jsonl'
         ftable = os.path.join(args.din, split) + '.tables.jsonl'
         fout = os.path.join(args.dout, split) + '_tok.jsonl'
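
The default --split value preserves the old behavior, since 'train,dev,test'.split(',') yields ['train', 'dev', 'test']; passing --split playground instead restricts annotation to a single user-created split.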

predict.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
#!/usr/bin/env python

# Use existing model to predict sql from tables and questions.
#
# For example, you can get a pretrained model from https://github.com/naver/sqlova/releases:
#   https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt
#   https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_best.pt
#
# Make sure you also have the following support files (see README for where to get them):
#   - bert_config_uncased_*.json
#   - pytorch_model_*.bin
#   - vocab_uncased_*.txt
#
# Finally, you need some data - some files called:
#   - <split>.db
#   - <split>.jsonl
#   - <split>.tables.jsonl
#   - <split>_tok.jsonl        # derived using annotate_ws.py
# You can play with the existing train/dev/test splits, or make your own with
# the add_csv.py and add_question.py utilities.
#
# Once you have all that, you are ready to predict, using:
#   python predict.py \
#     --bert_type_abb uL \     # need to match the architecture of the model you are using
#     --model_file <path to models>/model_best.pt \
#     --bert_model_file <path to models>/model_bert_best.pt \
#     --bert_path <path to bert_config/pytorch model/vocab> \
#     --result_path <where to place results> \
#     --data_path <path to db/jsonl/tables.jsonl> \
#     --split <split>
#
# Results will be in a file called results_<split>.jsonl in the result_path.

import argparse, os
from sqlnet.dbengine import DBEngine
from sqlova.utils.utils_wikisql import *
from train import construct_hyper_param, get_models, get_opt

# This is a stripped-down version of the test() method in train.py - identical, except:
#   - it does not attempt to measure accuracy, and does not expect the data to be labelled
#   - it saves plain text sql queries
#
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):

    model.eval()
    model_bert.eval()

    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for iB, t in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi_corenlp = get_g_wvi_corenlp(t)
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
            nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        if not EG:
            # No execution-guided decoding
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided decoding
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu,
                                                                                            l_hs, engine, tb,
                                                                                            nlu_t, nlu_tt,
                                                                                            tt_to_t_idx, nlu,
                                                                                            beam_size=beam_size)
            # sort and generate
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Following variables are just for consistency with the no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None

        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        for b, (pr_sql_i1, pr_sql_q1) in enumerate(zip(pr_sql_i, pr_sql_q)):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_q1
            results.append(results1)

    return results

## Set up hyper parameters and paths
parser = argparse.ArgumentParser()
parser.add_argument("--model_file", required=True, help='model file to use (e.g. model_best.pt)')
parser.add_argument("--bert_model_file", required=True, help='bert model file to use (e.g. model_bert_best.pt)')
parser.add_argument("--bert_path", required=True, help='path to bert files (bert_config*.json etc)')
parser.add_argument("--data_path", required=True, help='path to *.jsonl and *.db files')
parser.add_argument("--split", required=True, help='prefix of jsonl and db files (e.g. dev)')
parser.add_argument("--result_path", required=True, help='directory in which to place results')
args = construct_hyper_param(parser)

BERT_PT_PATH = args.bert_path
path_save_for_evaluation = args.result_path

# Load pre-trained models
path_model_bert = args.bert_model_file
path_model = args.model_file
model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
opt, opt_bert = get_opt(model, model_bert, args)

# Load data
dev_data, dev_table = load_wikisql_data(args.data_path, mode=args.split, toy_model=args.toy_model, toy_size=args.toy_size, no_hs_tok=True)
dev_loader = torch.utils.data.DataLoader(
    batch_size=args.bS,
    dataset=dev_data,
    shuffle=False,
    num_workers=1,
    collate_fn=lambda x: x  # now dictionary values are not merged!
)

# Run prediction
with torch.no_grad():
    results = predict(dev_loader,
                      dev_table,
                      model,
                      model_bert,
                      bert_config,
                      tokenizer,
                      args.max_seq_length,
                      args.num_target_layers,
                      detail=False,
                      path_db=args.data_path,
                      st_pos=0,
                      dset_name=args.split, EG=args.EG)

# Save results
save_for_evaluation(path_save_for_evaluation, results, args.split)
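
Each line of results_<split>.jsonl is one record from the loop above. For the playground example a line might look like this (the "query" indices and rendered SQL are illustrative; the actual prediction depends on the model):

    {"query": {"agg": 0, "sel": 0, "conds": [[1, 0, "11"]]}, "table_id": "abbrev", "nlu": "what state has ansi digits of 11", "sql": "SELECT (Name) FROM table_abbrev WHERE ANSI = '11'"}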

train.py

Lines changed: 3 additions & 3 deletions
@@ -127,8 +127,8 @@ def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining):
 
     return model_bert, tokenizer, bert_config
 
-def get_opt(model, model_bert, fine_tune):
-    if fine_tune:
+def get_opt(model, model_bert, args):
+    if args.fine_tune:
         opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, weight_decay=0)
 
@@ -582,7 +582,7 @@ def print_result(epoch, acc, dname):
     # model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
 
     ## 5. Get optimizers
-    opt, opt_bert = get_opt(model, model_bert, args.fine_tune)
+    opt, opt_bert = get_opt(model, model_bert, args)
 
     ## 6. Train
     acc_lx_t_best = -1
