Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import (
"path"
"path/filepath"
"regexp"
"sqlflow.org/sqlflow/go/codegen/experimental"
"strings"
"sync"

"sqlflow.org/sqlflow/go/codegen/experimental"

"sqlflow.org/sqlflow/go/verifier"

"sqlflow.org/sqlflow/go/codegen/optimize"
Expand Down
3 changes: 2 additions & 1 deletion go/ir/ir_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ func GenerateTrainStmt(slct *parser.SQLFlowSelectStmt) (*TrainStmt, error) {
}
label := &NumericColumn{
FieldDesc: &FieldDesc{
Name: tc.Label,
Name: tc.Label,
Shape: []int{1},
}}

vslct, _ := parseValidationSelect(attrList)
Expand Down
1 change: 0 additions & 1 deletion go/sql/executor_ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ func TestExecutorTrainAndPredictDNN(t *testing.T) {
}

func TestExecutorTrainAndPredictClusteringLocalFS(t *testing.T) {
t.Skip("fix random nan loss error then re-enable this test")
a := assert.New(t)
modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models")
a.Nil(e)
Expand Down
13 changes: 8 additions & 5 deletions python/runtime/local/create_result_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def create_predict_table(conn, select, result_table, train_label_desc,
"""
name_and_types = db.selected_columns_and_types(conn, select)
train_label_index = -1
for i, (name, _) in enumerate(name_and_types):
if name == train_label_desc.name:
train_label_index = i
break
if train_label_desc:
for i, (name, _) in enumerate(name_and_types):
if name == train_label_desc.name:
train_label_index = i
break

if train_label_index >= 0:
del name_and_types[train_label_index]
Expand All @@ -45,10 +46,12 @@ def create_predict_table(conn, select, result_table, train_label_desc,
column_strs.append("%s %s" %
(name, db.to_db_field_type(conn.driver, typ)))

if train_label_desc.format == DataFormat.PLAIN:
if train_label_desc and train_label_desc.format == DataFormat.PLAIN:
train_label_field_type = DataType.to_db_field_type(
conn.driver, train_label_desc.dtype)
else:
# if no train lable description is provided (clustering),
# we treat the column type as string
train_label_field_type = DataType.to_db_field_type(
conn.driver, DataType.STRING)

Expand Down
5 changes: 3 additions & 2 deletions python/runtime/local/tensorflow_submitter/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def pred(datasource, select, result_table, pred_label_name, model):

model_params = model.get_meta("attributes")
train_fc_map = model.get_meta("features")
train_label_desc = model.get_meta("label").get_field_desc()[0]
train_label_name = train_label_desc.name
label_meta = model.get_meta("label")
train_label_desc = label_meta.get_field_desc()[0] if label_meta else None
train_label_name = train_label_desc.name if train_label_desc else None
estimator_string = model.get_meta("class_name")
save = "model_save"

Expand Down
7 changes: 6 additions & 1 deletion python/runtime/step/tensorflow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ def train_step(original_sql,
feature_column_names = [fd.name for fd in field_descs]
feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
for fd in field_descs])
label_meta = fc_label_ir.get_field_desc()[0].to_dict(dtype_to_string=True)

# no label for clustering model
label_meta = None
if fc_label_ir:
label_meta = fc_label_ir.get_field_desc()[0].to_dict(
dtype_to_string=True)

feature_column_names_map = dict()
for target in fc_map_ir:
Expand Down