Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions go/cmd/sqlflowserver/e2e_common_cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -1115,3 +1115,25 @@ USING %s;`, caseTrainTable, predictTableName, predWith, caseInto)
}
}
}

// casePassSelectedColumnsToPredResult verifies that the columns selected in a
// TO PREDICT statement (including columns not used at training time, here
// petal_width) are passed through into the prediction result table, with the
// predicted label appended as the last column.
func casePassSelectedColumnsToPredResult(t *testing.T) {
	a := assert.New(t)
	sql := fmt.Sprintf(`
SELECT sepal_length, sepal_width, petal_length
FROM %s
TO TRAIN sqlflow_models.OneClassSVM
WITH model.kernel = "rbf"
INTO %s;

SELECT sepal_length, sepal_width, petal_length, petal_width
FROM %s
TO PREDICT %s.label
USING %s;`, caseTrainTable, caseInto, caseTrainTable, casePredictTable, caseInto)
	_, _, _, err := connectAndRunSQL(sql)
	a.NoError(err)
	// Inspect the result table's header: it must contain every selected
	// column plus the prediction label column, in order.
	selectPredSQL := fmt.Sprintf(`SELECT * FROM %s LIMIT 1;`, casePredictTable)
	headers, _, _, err := connectAndRunSQL(selectPredSQL)
	// Check the query error before comparing headers; otherwise a failed
	// query surfaces as a confusing header mismatch instead of the root cause.
	a.NoError(err)
	expectedHeaders := []string{"sepal_length",
		"sepal_width", "petal_length", "petal_width", "label"}
	a.True(reflect.DeepEqual(headers, expectedHeaders))
}
1 change: 1 addition & 0 deletions go/cmd/sqlflowserver/e2e_mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ func TestEnd2EndMySQL(t *testing.T) {
t.Run("CaseSQLByPassLeftJoin", CaseSQLByPassLeftJoin)
t.Run("CaseTrainRegression", caseTrainRegression)
t.Run("CaseScoreCard", caseScoreCard)
t.Run("CasePassSelectedColumnsToPredResult", casePassSelectedColumnsToPredResult)

t.Run("CaseOneClassSVMModel", func(t *testing.T) {
caseOneClassSVMModel(t, nil)
Expand Down
34 changes: 28 additions & 6 deletions python/runtime/tensorflow/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ def eval_input_fn(batch_size, cache=False):
dataset = dataset.cache()
return dataset

def to_feature_sample(row, selected_cols):
    """Build a {feature_name: tensor} dict for one prediction-input row.

    Args:
        row: one record of raw column values; each feature's value is
            looked up by its position in ``selected_cols``.
        selected_cols: ordered list of the column names that ``row``'s
            values correspond to (may contain more columns than the
            model's features, e.g. a pass-through label column).

    Returns:
        dict mapping each name in the enclosing ``feature_column_names``
        to a tf.SparseTensor or tf.constant, suitable as one model input
        sample. Uses ``feature_metas`` from the enclosing scope to decide
        how to build each tensor.
    """
    features = {}
    for name in feature_column_names:
        # Locate this feature's value by column name, so the SELECT
        # column order does not have to match the feature order.
        row_val = row[selected_cols.index(name)]
        if feature_metas[name].get("delimiter_kv", "") != "":
            # kv list that should be parsed to two features.
            if feature_metas[name]["is_sparse"]:
                # row_val unpacks as SparseTensor args — presumably
                # (indices, values, dense_shape); TODO confirm upstream.
                # The key feature uses all-ones values (presence only);
                # the "<name>_weight" feature carries the actual values.
                features[name] = tf.SparseTensor(
                    row_val[0], tf.ones_like(tf.reshape(row_val[0], [-1])),
                    row_val[2])
                features["_".join([name,
                                   "weight"])] = tf.SparseTensor(*row_val)
            else:
                raise ValueError(
                    "not supported DENSE column with key:value"
                    "list format.")
        else:
            if feature_metas[name]["is_sparse"]:
                features[name] = tf.SparseTensor(*row_val)
            else:
                # Dense scalar: wrap as a rank-2-like constant so it
                # matches the batch-of-one shape the model expects.
                features[name] = tf.constant(([row_val], ))
    return features

if not hasattr(classifier, 'sqlflow_predict_one'):
# NOTE: load_weights should be called by keras models only.
# NOTE: always use batch_size=1 when predicting to get the pairs of
Expand All @@ -108,8 +131,10 @@ def eval_input_fn(batch_size, cache=False):
column_names.append(result_col_name)

column_names.extend(extra_result_cols)

with db.buffered_db_writer(conn, result_table, column_names, 100) as w:
for features in pred_dataset:
for row, _ in predict_generator():
features = to_feature_sample(row, column_names)
if hasattr(classifier, 'sqlflow_predict_one'):
result = classifier.sqlflow_predict_one(features)
else:
Expand Down Expand Up @@ -143,15 +168,12 @@ def eval_input_fn(batch_size, cache=False):
result = result[0].argmax(axis=-1)
else:
result = result[0] # multiple regression result
row = []
for idx, name in enumerate(feature_column_names):
val = features[name].numpy()[0][0]
row.append(str(val))

row.append(encode_pred_result(result))
if extra_pred_outputs is not None:
row.extend([encode_pred_result(p) for p in extra_pred_outputs])

if train_label_index != -1 and len(row) > train_label_index:
del row[train_label_index]
w.write(row)
del pred_dataset

Expand Down