|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Use existing model to predict sql from tables and questions. |
| 4 | +# |
| 5 | +# For example, you can get a pretrained model from https://github.com/naver/sqlova/releases: |
| 6 | +# https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt |
| 7 | +# https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_best.pt |
| 8 | +# |
| 9 | +# Make sure you also have the following support files (see README for where to get them): |
| 10 | +# - bert_config_uncased_*.json |
| 11 | +# - pytorch_model_*.bin |
| 12 | +# - vocab_uncased_*.txt |
| 13 | +# |
| 14 | +# Finally, you need some data - some files called: |
| 15 | +# - <split>.db |
| 16 | +# - <split>.jsonl |
| 17 | +# - <split>.tables.jsonl |
| 18 | +# - <split>_tok.jsonl # derived using annotate_ws.py |
| 19 | +# You can play with the existing train/dev/test splits, or make your own with |
| 20 | +# the add_csv.py and add_question.py utilities. |
| 21 | +# |
| 22 | +# Once you have all that, you are ready to predict, using: |
| 23 | +# python predict.py \ |
| 24 | +# --bert_type_add uL \ # need to match the architecture of the model you are using |
| 25 | +# --model_file <path to models>/model_best.pt \ |
| 26 | +# --bert_model_file <path to models>/model_bert_best.pt \ |
| 27 | +# --bert_path <path to bert_config/pytorch model/vocab> \ |
| 28 | +# --result_path <where to place results> \ |
| 29 | +# --data_path <path to db/jsonl/tables.jsonl> \ |
| 30 | +# --split <split> |
| 31 | +# |
| 32 | +# Results will be in a file called results_<split>.jsonl in the result_path. |
| 33 | + |
| 34 | +import argparse, os |
| 35 | +from sqlnet.dbengine import DBEngine |
| 36 | +from sqlova.utils.utils_wikisql import * |
| 37 | +from train import construct_hyper_param, get_models, get_opt |
| 38 | + |
# This is a stripped down version of the test() method in train.py - identical, except:
# - does not attempt to measure accuracy and indeed does not expect the data to be labelled.
# - saves plain text sql queries.
#
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Predict SQL queries for every question served by ``data_loader``.

    Args:
        data_loader: torch DataLoader yielding batches of tokenized examples.
        data_table: table metadata loaded from <split>.tables.jsonl.
        model: trained SQLova decoder (sequence-to-SQL head).
        model_bert: trained BERT encoder.
        bert_config: BERT configuration matching ``model_bert``.
        tokenizer: BERT tokenizer.
        max_seq_length: maximum BERT input length.
        num_target_layers: number of BERT output layers fed to the decoder.
        detail, st_pos, cnt_tot: kept for signature parity with train.test();
            unused here.
        EG: if True, use execution-guided beam decoding against the database.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing ``<dset_name>.db``.
        dset_name: split name used to locate the database file.

    Returns:
        List of dicts, one per question, with keys ``query`` (structured SQL),
        ``table_id``, ``nlu`` (the question), and ``sql`` (plain-text SQL).
    """
    model.eval()
    model_bert.eval()

    # Engine is only consulted by the execution-guided branch, but it is cheap
    # to open up front and keeps the two branches symmetric.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))
    results = []
    for t in data_loader:
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        # NOTE: unlike train.test(), we deliberately do NOT call get_g()/
        # get_g_wvi_corenlp() here - those extract gold labels, which
        # unlabelled prediction data does not have (and nothing below used
        # their results anyway).
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        if not EG:
            # No execution-guided decoding: greedy decode from the model scores.
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided decoding: beam search, pruning candidates that
            # fail to execute against the database.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu,
                                                                                            l_hs, engine, tb,
                                                                                            nlu_t, nlu_tt,
                                                                                            tt_to_t_idx, nlu,
                                                                                            beam_size=beam_size)
            # Sort where clauses and regenerate the canonical structured form.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Following variables are just for consistency with no-EG case.
            pr_wvi = None  # not used
            pr_wv_str = None
            pr_wv_str_wp = None

        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        for b, (pr_sql_i1, pr_sql_q1) in enumerate(zip(pr_sql_i, pr_sql_q)):
            results.append({
                "query": pr_sql_i1,       # structured (dict) SQL prediction
                "table_id": tb[b]["id"],  # table the question is asked against
                "nlu": nlu[b],            # the natural-language question
                "sql": pr_sql_q1,         # plain-text SQL string
            })

    return results
| 92 | + |
## Set up hyper parameters and paths
parser = argparse.ArgumentParser()
parser.add_argument("--model_file", required=True, help='model file to use (e.g. model_best.pt)')
parser.add_argument("--bert_model_file", required=True, help='bert model file to use (e.g. model_bert_best.pt)')
parser.add_argument("--bert_path", required=True, help='path to bert files (bert_config*.json etc)')
parser.add_argument("--data_path", required=True, help='path to *.jsonl and *.db files')
parser.add_argument("--split", required=True, help='prefix of jsonl and db files (e.g. dev)')
parser.add_argument("--result_path", required=True, help='directory in which to place results')
# construct_hyper_param (from train.py) adds the shared model/bert flags and parses.
args = construct_hyper_param(parser)

BERT_PT_PATH = args.bert_path
path_save_for_evaluation = args.result_path

# Load pre-trained models from the checkpoint files given on the command line.
path_model_bert = args.bert_model_file
path_model = args.model_file
model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
# NOTE: unlike train.py, no optimizers are built here - this script only runs
# inference (everything below executes under torch.no_grad()), so get_opt()
# would allocate training state that is never used.

# Load data for the requested split. no_hs_tok=True: header tokens are
# produced by the BERT tokenizer later, not read from a *_hs_tok file.
dev_data, dev_table = load_wikisql_data(args.data_path, mode=args.split, toy_model=args.toy_model, toy_size=args.toy_size, no_hs_tok=True)
dev_loader = torch.utils.data.DataLoader(
    batch_size=args.bS,
    dataset=dev_data,
    shuffle=False,          # keep question order so results line up with input
    num_workers=1,
    collate_fn=lambda x: x  # now dictionary values are not merged!
)

# Run prediction (inference only - disable autograd bookkeeping).
with torch.no_grad():
    results = predict(dev_loader,
                      dev_table,
                      model,
                      model_bert,
                      bert_config,
                      tokenizer,
                      args.max_seq_length,
                      args.num_target_layers,
                      detail=False,
                      path_db=args.data_path,
                      st_pos=0,
                      dset_name=args.split, EG=args.EG)

# Save results as results_<split>.jsonl in the result path.
save_for_evaluation(path_save_for_evaluation, results, args.split)
0 commit comments