Skip to content
Merged
4 changes: 2 additions & 2 deletions go/codegen/alps/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
var program bytes.Buffer
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
"intArrayToJSONString": codegen.MarshalToJSONString,
"attrToPythonValue": tensorflow.AttrToPythonValue,
"DTypeToString": tensorflow.DTypeToString,
"attrToPythonValue": codegen.AttrToPythonValue,
"DTypeToString": codegen.DTypeToString,
}).Parse(templateTrain))
if err := trainTemplate.Execute(&program, filler); err != nil {
return "", err
Expand Down
79 changes: 79 additions & 0 deletions go/codegen/codegen_python_values.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package codegen

import (
"fmt"
"strings"

"sqlflow.org/sqlflow/go/ir"
)

// DTypeToString returns string value of dtype
func DTypeToString(dt int) string {
switch dt {
case ir.Float:
return "float32"
case ir.Int:
return "int64"
case ir.String:
return "string"
default:
return ""
}
}

// AttrToPythonValue format the WITH attributes to corresponding Python code.
func AttrToPythonValue(attr interface{}) string {
switch attr.(type) {
case bool:
return strings.Title(fmt.Sprintf("%v", attr.(bool)))
case int:
return fmt.Sprintf("%d", attr.(int))
case int64:
return fmt.Sprintf("%d", attr.(int64))
case float32:
return fmt.Sprintf("%f", attr.(float32))
case float64: // FIXME(typhoonzero): may never use
return fmt.Sprintf("%f", attr.(float64))
case []int:
intArrayAttrStr, _ := MarshalToJSONString(attr.([]int))
return intArrayAttrStr
case []string:
l := attr.([]string)
if len(l) == 0 {
return "[]"
}
stringListStr, _ := MarshalToJSONString(l)
return stringListStr
case []interface{}:
tmplist := attr.([]interface{})
if len(tmplist) > 0 {
if _, ok := tmplist[0].(int); ok {
intlist := []int{}
for _, v := range tmplist {
intlist = append(intlist, v.(int))
}
intlistStr, _ := MarshalToJSONString(intlist)
return intlistStr
}
}
// TODO(typhoonzero): support []float etc.
return "[]"
case string:
return fmt.Sprintf("\"%s\"", attr.(string))
default:
return ""
}
}
30 changes: 30 additions & 0 deletions go/codegen/codegen_python_values_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package codegen

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestAttrToPython(t *testing.T) {
a := assert.New(t)
l := []string{"a", "b", "c"}
v := AttrToPythonValue(l)
a.Equal("[\"a\",\"b\",\"c\"]", v)
l = []string{}
v = AttrToPythonValue(l)
a.Equal("[]", v)
}
3 changes: 1 addition & 2 deletions go/codegen/experimental/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error)
for _, sql := range sqls {
if sql.IsExtendedSyntax() {
if sql.Train {
// TODO(typhoonzero): use feature derivation at runtime, call GenerateTrainStmt only.
r, err = ir.GenerateTrainStmtWithInferredColumns(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false, false)
r, err = ir.GenerateTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.ShowTrain {
r, err = ir.GenerateShowTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.Explain {
Expand Down
8 changes: 8 additions & 0 deletions go/codegen/experimental/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,18 @@ func TestExperimentalXGBCodegen(t *testing.T) {
if os.Getenv("SQLFLOW_TEST_DB") != "mysql" {
t.Skipf("skip TestExperimentalXGBCodegen of DB type %s", os.Getenv("SQLFLOW_TEST_DB"))
}
// test without COLUMN clause
sql := "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 LABEL class INTO sqlflow_models.xgb_classification;"
s := &pb.Session{DbConnStr: database.GetTestingMySQLURL()}
_, err := GenerateCodeCouler(sql, s)
if err != nil {
t.Errorf("error %s", err)
}

// test with COLUMN clause
sql = "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 COLUMN petal_length LABEL class INTO sqlflow_models.xgb_classification;"
_, err = GenerateCodeCouler(sql, s)
if err != nil {
t.Errorf("error %s", err)
}
}
Loading