43 changes: 12 additions & 31 deletions go/cmd/sqlflowserver/e2e_alps_test.go
@@ -49,47 +49,28 @@ func TestEnd2EndMaxComputeALPS(t *testing.T) {
if caseDB == "" {
t.Fatalf("Must set env SQLFLOW_TEST_DB_MAXCOMPUTE_PROJECT when testing ALPS cases (SQLFLOW_submitter=alps)!!")
}
err = prepareTestData(dbConnStr)
if err != nil {
t.Fatalf("prepare test dataset failed: %v", err)
}

go start(modelDir, caCrt, caKey, unitTestPort, false)
server.WaitPortReady(fmt.Sprintf("localhost:%d", unitTestPort), 0)

t.Run("CaseTrainALPS", CaseTrainALPS)
t.Run("CaseTrainALPSFeatureMap", CaseTrainALPSFeatureMap)
t.Run("CaseTrainALPSRemoteModel", CaseTrainALPSRemoteModel)
// TODO(typhoonzero): add this back later
// t.Run("CaseTrainALPSFeatureMap", CaseTrainALPSFeatureMap)
// t.Run("CaseTrainALPSRemoteModel", CaseTrainALPSRemoteModel)
}

// CaseTrainALPS is a case for training models using ALPS without a feature_map table
func CaseTrainALPS(t *testing.T) {
a := assert.New(t)
trainSQL := fmt.Sprintf(`
SELECT deep_id, user_space_stat, user_behavior_stat, space_stat, l
FROM %s.sparse_column_test
LIMIT 100
TO TRAIN DNNClassifier
WITH
model.n_classes = 2,
model.hidden_units = [10, 20],
train.batch_size = 10,
engine.ps_num = 0,
engine.worker_num = 0,
engine.type = local,
validation.table = "%s.sparse_column_test"
COLUMN
SPARSE(deep_id,15033,COMMA,int),
SPARSE(user_space_stat,310,COMMA,int),
SPARSE(user_behavior_stat,511,COMMA,int),
SPARSE(space_stat,418,COMMA,int),
EMBEDDING(CATEGORY_ID(deep_id,15033,COMMA),512,mean),
EMBEDDING(CATEGORY_ID(user_space_stat,310,COMMA),64,mean),
EMBEDDING(CATEGORY_ID(user_behavior_stat,511,COMMA),64,mean),
EMBEDDING(CATEGORY_ID(space_stat,418,COMMA),64,mean)
LABEL l
INTO model_table;
`, caseDB, caseDB)
trainSQL := fmt.Sprintf(`SELECT * FROM %s.sqlflow_test_iris_train
TO TRAIN DNNClassifier
WITH
model.n_classes = 3,
model.hidden_units = [10, 20],
train.batch_size = 10,
validation.select = "SELECT * FROM %s.sqlflow_test_iris_test"
LABEL class
INTO model_table;`, caseDB, caseDB)
_, _, _, err := connectAndRunSQL(trainSQL)
if err != nil {
a.Fail("run trainSQL error: %v", err)
62 changes: 62 additions & 0 deletions go/codegen/alps/codegen.go
@@ -0,0 +1,62 @@
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package alps

import (
"bytes"
"fmt"
"strings"
"text/template"

"sqlflow.org/sqlflow/go/codegen"
"sqlflow.org/sqlflow/go/codegen/tensorflow"
"sqlflow.org/sqlflow/go/ir"
pb "sqlflow.org/sqlflow/go/proto"
)

// Train generates code to train a model using ALPS.
func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
trainParams, validateParams, modelParams := tensorflow.CategorizeAttributes(trainStmt)
featureColumnsCode, fieldDescs, err := tensorflow.DeriveFeatureColumnCodeAndFieldDescs(trainStmt)
if err != nil {
return "", err
}

filler := &trainFiller{
DataSource: session.DbConnStr,
TrainSelect: trainStmt.Select,
ValidationSelect: trainStmt.ValidationSelect,
Estimator: trainStmt.Estimator,
FieldDescs: fieldDescs,
FeatureColumnCode: fmt.Sprintf("{%s}", strings.Join(featureColumnsCode, ",\n")),
Y: trainStmt.Label.GetFieldDesc()[0],
ModelParams: modelParams,
TrainParams: trainParams,
ValidationParams: validateParams,
Save: trainStmt.Into,
TmpTrainTable: trainStmt.TmpTrainTable,
TmpValidateTable: trainStmt.TmpValidateTable,
}

var program bytes.Buffer
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
"intArrayToJSONString": codegen.MarshalToJSONString,
"attrToPythonValue": tensorflow.AttrToPythonValue,
"DTypeToString": tensorflow.DTypeToString,
}).Parse(templateTrain))
if err := trainTemplate.Execute(&program, filler); err != nil {
return "", err
}
return program.String(), nil
}
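For context, a minimal sketch of how this new entry point can be driven end to end. It mirrors the unit test below; the mock statement and session come from the existing test helpers, and writing the generated program to a file is illustrative only, not how the submitter actually runs it:

package main

import (
	"fmt"
	"io/ioutil"
	"log"

	"sqlflow.org/sqlflow/go/codegen/alps"
	"sqlflow.org/sqlflow/go/database"
	"sqlflow.org/sqlflow/go/ir"
	pb "sqlflow.org/sqlflow/go/proto"
)

func main() {
	// Build a session the same way codegen_test.go below does.
	db := database.GetTestingDBSingleton()
	session := &pb.Session{
		DbConnStr: fmt.Sprintf("%s://%s", db.DriverName, db.DataSourceName),
	}
	// Generate the Python training program from a mocked TRAIN statement.
	program, err := alps.Train(ir.MockTrainStmt(false), session)
	if err != nil {
		log.Fatalf("ALPS codegen failed: %v", err)
	}
	// Writing the program out is illustrative; the real submitter executes it.
	if err := ioutil.WriteFile("alps_train.py", []byte(program), 0644); err != nil {
		log.Fatal(err)
	}
}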
41 changes: 41 additions & 0 deletions go/codegen/alps/codegen_test.go
@@ -0,0 +1,41 @@
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package alps

import (
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"

"sqlflow.org/sqlflow/go/database"
"sqlflow.org/sqlflow/go/ir"
pb "sqlflow.org/sqlflow/go/proto"
)

func mockSession() *pb.Session {
db := database.GetTestingDBSingleton()
return &pb.Session{DbConnStr: fmt.Sprintf("%s://%s", db.DriverName, db.DataSourceName)}
}

func TestALPSCodegen(t *testing.T) {
if os.Getenv("SQLFLOW_TEST_DB") != "maxcompute" {
t.Skipf("skipping TestALPSCodegen; it only runs when SQLFLOW_TEST_DB=maxcompute")
}
a := assert.New(t)
tir := ir.MockTrainStmt(false)
_, err := Train(tir, mockSession())
a.NoError(err)
}
168 changes: 168 additions & 0 deletions go/codegen/alps/template_train.go
@@ -0,0 +1,168 @@
// Copyright 2020 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package alps

import "sqlflow.org/sqlflow/go/ir"

type trainFiller struct {
DataSource string
TrainSelect string
ValidationSelect string
Estimator string
FieldDescs map[string][]*ir.FieldDesc
FeatureColumnCode string
Y *ir.FieldDesc
ModelParams map[string]interface{}
TrainParams map[string]interface{}
ValidationParams map[string]interface{}
Save string
TmpTrainTable string
TmpValidateTable string
}

var templateTrain = `import copy
import os
import shutil

import tensorflow as tf
from alps.framework.column.column import (DenseColumn, GroupedSparseColumn,
SparseColumn)
from alps.framework.engine import LocalEngine
from alps.framework.experiment import EstimatorBuilder
from alps.io.base import OdpsConf
from runtime import db
from runtime.alps.train import train
from runtime.tensorflow.get_tf_version import tf_is_version2

feature_column_names = [{{range $target, $desclist := .FieldDescs}}{{range $desclist}}
"{{.Name}}",
{{end}}{{end}}]

feature_metas = dict()
{{ range $target, $desclist := .FieldDescs }}
{{ range $value := $desclist }}
feature_metas["{{$value.Name}}"] = {
"feature_name": "{{$value.Name}}",
"dtype": "{{$value.DType | DTypeToString}}",
"delimiter": "{{$value.Delimiter}}",
"format": "{{$value.Format}}",
"shape": {{$value.Shape | intArrayToJSONString}},
"is_sparse": "{{$value.IsSparse}}" == "true"
}
{{end}}
{{end}}

label_meta = {
"feature_name": "{{.Y.Name}}",
"dtype": "{{.Y.DType | DTypeToString}}",
"delimiter": "{{.Y.Delimiter}}",
"shape": {{.Y.Shape | intArrayToJSONString}},
"is_sparse": "{{.Y.IsSparse}}" == "true"
}

model_params=dict()
{{range $k, $v := .ModelParams}}
model_params["{{$k}}"]={{$v | attrToPythonValue}}
{{end}}

# Construct optimizer objects to pass to model initializer.
# The original model_params is serializable (it does not contain tf.xxx objects).
model_params_constructed = copy.deepcopy(model_params)
for optimizer_arg in ["optimizer", "dnn_optimizer", "linear_optimizer"]:
if optimizer_arg in model_params_constructed:
model_params_constructed[optimizer_arg] = eval(model_params_constructed[optimizer_arg])

if "loss" in model_params_constructed:
model_params_constructed["loss"] = eval(model_params_constructed["loss"])


class SQLFlowEstimatorBuilder(EstimatorBuilder):
def _build(self, experiment, run_config):
feature_columns_map = {{.FeatureColumnCode}}
if feature_columns_map.get("feature_columns"):
feature_columns = feature_columns_map["feature_columns"]
else:
raise ValueError("unsupported feature column map")
model_params_constructed["feature_columns"] = feature_columns
return tf.estimator.{{.Estimator}}(config=run_config,
sneaxiy (Collaborator) commented on Jul 31, 2020:
Seems that the ALPS submitter only supports the built-in TensorFlow estimators? And does it support feature columns other than DENSE and SPARSE?

Author (Collaborator) replied:

  1. Custom estimators can also be supported, I think.
  2. ALPS supports dense features, sparse features, and grouped sparse features; all of these feature input types will be supported and added to the unit tests in upcoming pull requests.
**model_params_constructed)


if __name__ == "__main__":
if tf_is_version2():
raise ValueError("ALPS must run with TensorFlow == 1.15.x")

driver, dsn = "{{.DataSource}}".split("://")
user, passwd, endpoint, odps_project = db.parseMaxComputeDSN(dsn)
odps_conf = OdpsConf(
accessid=user,
accesskey=passwd,
# endpoint should look like: "https://service.cn.maxcompute.aliyun.com/api"
endpoint=endpoint,
project=odps_project)

features = []
for col_name in feature_column_names:
# NOTE: add sparse columns like: SparseColumn(name="deep_id", shape=[15033], dtype="int")
if feature_metas[col_name]["is_sparse"]:
features.append(SparseColumn(name=feature_metas[col_name]["feature_name"],
shape=feature_metas[col_name]["shape"],
dtype=feature_metas[col_name]["dtype"],
separator=feature_metas[col_name]["delimiter"]))
else:
features.append(DenseColumn(name=feature_metas[col_name]["feature_name"],
shape=feature_metas[col_name]["shape"],
dtype=feature_metas[col_name]["dtype"]))
labels = DenseColumn(name=label_meta["feature_name"],
shape=label_meta["shape"],
dtype=label_meta["dtype"])

try:
os.mkdir("scratch")
except FileExistsError:
pass

train_max_steps = {{index .TrainParams "max_steps" | attrToPythonValue}}
train_max_steps = None if train_max_steps == 0 else train_max_steps

# TODO(typhoonzero): support pass feature_map_table from WITH attributes.
# TODO(typhoonzero): pass the actual user_id.
# TODO(typhoonzero): pass engine config to submit jobs to the cluster.
train(SQLFlowEstimatorBuilder(),
odps_conf=odps_conf,
project=odps_project,
train_table="{{.TmpTrainTable}}",
eval_table="{{.TmpValidateTable}}",
features=features,
labels=labels,
feature_map_table="",
feature_map_partition="",
epochs=1,
batch_size=2,
shuffle=False,
shuffle_bufsize=128,
cache_file="",
max_steps=train_max_steps,
eval_steps={{index .ValidationParams "steps" | attrToPythonValue}},
eval_batch_size=1,
eval_start_delay={{index .ValidationParams "start_delay_secs" | attrToPythonValue}},
eval_throttle={{index .ValidationParams "throttle_secs" | attrToPythonValue}},
drop_remainder=True,
export_path="./scratch/model",
scratch_dir="./scratch",
user_id="",
engine_config={"name": "LocalEngine"},
exit_on_submit=False)
shutil.rmtree("scratch")
`
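As a side note, the shape literals in the generated Python above (e.g. "shape": [15033]) are produced by the "intArrayToJSONString" helper registered in codegen.go. A minimal, self-contained sketch of that rendering; the inline jsonString function stands in for codegen.MarshalToJSONString and is an assumption, not the real implementation:

package main

import (
	"encoding/json"
	"os"
	"text/template"
)

// jsonString stands in for codegen.MarshalToJSONString: it renders a Go
// value (here an int slice) as a JSON literal that is also valid Python.
func jsonString(v interface{}) (string, error) {
	b, err := json.Marshal(v)
	return string(b), err
}

func main() {
	tmpl := template.Must(template.New("demo").Funcs(template.FuncMap{
		"intArrayToJSONString": jsonString,
	}).Parse(`"shape": {{.Shape | intArrayToJSONString}}` + "\n"))
	// Prints: "shape": [15033]
	tmpl.Execute(os.Stdout, struct{ Shape []int }{Shape: []int{15033}})
}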