Skip to content

Commit 66f606e

Browse files
authored
alps submitter and codegen (#2771)
* test alps submitter * add alps codegen * update
1 parent c7217ee commit 66f606e

File tree

9 files changed

+383
-87
lines changed

9 files changed

+383
-87
lines changed

go/cmd/sqlflowserver/e2e_alps_test.go

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -49,47 +49,28 @@ func TestEnd2EndMaxComputeALPS(t *testing.T) {
4949
if caseDB == "" {
5050
t.Fatalf("Must set env SQLFLOW_TEST_DB_MAXCOMPUTE_PROJECT when testing ALPS cases (SQLFLOW_submitter=alps)!!")
5151
}
52-
err = prepareTestData(dbConnStr)
53-
if err != nil {
54-
t.Fatalf("prepare test dataset failed: %v", err)
55-
}
5652

5753
go start(modelDir, caCrt, caKey, unitTestPort, false)
5854
server.WaitPortReady(fmt.Sprintf("localhost:%d", unitTestPort), 0)
5955

6056
t.Run("CaseTrainALPS", CaseTrainALPS)
61-
t.Run("CaseTrainALPSFeatureMap", CaseTrainALPSFeatureMap)
62-
t.Run("CaseTrainALPSRemoteModel", CaseTrainALPSRemoteModel)
57+
// TODO(typhoonzero): add this back later
58+
// t.Run("CaseTrainALPSFeatureMap", CaseTrainALPSFeatureMap)
59+
// t.Run("CaseTrainALPSRemoteModel", CaseTrainALPSRemoteModel)
6360
}
6461

6562
// CaseTrainALPS is a case for training models using ALPS without a feature_map table
6663
func CaseTrainALPS(t *testing.T) {
6764
a := assert.New(t)
68-
trainSQL := fmt.Sprintf(`
69-
SELECT deep_id, user_space_stat, user_behavior_stat, space_stat, l
70-
FROM %s.sparse_column_test
71-
LIMIT 100
72-
TO TRAIN DNNClassifier
73-
WITH
74-
model.n_classes = 2,
75-
model.hidden_units = [10, 20],
76-
train.batch_size = 10,
77-
engine.ps_num = 0,
78-
engine.worker_num = 0,
79-
engine.type = local,
80-
validation.table = "%s.sparse_column_test"
81-
COLUMN
82-
SPARSE(deep_id,15033,COMMA,int),
83-
SPARSE(user_space_stat,310,COMMA,int),
84-
SPARSE(user_behavior_stat,511,COMMA,int),
85-
SPARSE(space_stat,418,COMMA,int),
86-
EMBEDDING(CATEGORY_ID(deep_id,15033,COMMA),512,mean),
87-
EMBEDDING(CATEGORY_ID(user_space_stat,310,COMMA),64,mean),
88-
EMBEDDING(CATEGORY_ID(user_behavior_stat,511,COMMA),64,mean),
89-
EMBEDDING(CATEGORY_ID(space_stat,418,COMMA),64,mean)
90-
LABEL l
91-
INTO model_table;
92-
`, caseDB, caseDB)
65+
trainSQL := fmt.Sprintf(`SELECT * FROM %s.sqlflow_test_iris_train
66+
TO TRAIN DNNClassifier
67+
WITH
68+
model.n_classes = 3,
69+
model.hidden_units = [10, 20],
70+
train.batch_size = 10,
71+
validation.select = "SELECT * FROM %s.sqlflow_test_iris_test"
72+
LABEL class
73+
INTO model_table;`, caseDB, caseDB)
9374
_, _, _, err := connectAndRunSQL(trainSQL)
9475
if err != nil {
9576
a.Fail("run trainSQL error: %v", err)

go/codegen/alps/codegen.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package alps
15+
16+
import (
17+
"bytes"
18+
"fmt"
19+
"strings"
20+
"text/template"
21+
22+
"sqlflow.org/sqlflow/go/codegen"
23+
"sqlflow.org/sqlflow/go/codegen/tensorflow"
24+
"sqlflow.org/sqlflow/go/ir"
25+
pb "sqlflow.org/sqlflow/go/proto"
26+
)
27+
28+
// Train generates code to train a model using ALPS.
29+
func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
30+
trainParams, validateParams, modelParams := tensorflow.CategorizeAttributes(trainStmt)
31+
featureColumnsCode, fieldDescs, err := tensorflow.DeriveFeatureColumnCodeAndFieldDescs(trainStmt)
32+
if err != nil {
33+
return "", err
34+
}
35+
36+
filler := &trainFiller{
37+
DataSource: session.DbConnStr,
38+
TrainSelect: trainStmt.Select,
39+
ValidationSelect: trainStmt.ValidationSelect,
40+
Estimator: trainStmt.Estimator,
41+
FieldDescs: fieldDescs,
42+
FeatureColumnCode: fmt.Sprintf("{%s}", strings.Join(featureColumnsCode, ",\n")),
43+
Y: trainStmt.Label.GetFieldDesc()[0],
44+
ModelParams: modelParams,
45+
TrainParams: trainParams,
46+
ValidationParams: validateParams,
47+
Save: trainStmt.Into,
48+
TmpTrainTable: trainStmt.TmpTrainTable,
49+
TmpValidateTable: trainStmt.TmpValidateTable,
50+
}
51+
52+
var program bytes.Buffer
53+
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
54+
"intArrayToJSONString": codegen.MarshalToJSONString,
55+
"attrToPythonValue": tensorflow.AttrToPythonValue,
56+
"DTypeToString": tensorflow.DTypeToString,
57+
}).Parse(templateTrain))
58+
if err := trainTemplate.Execute(&program, filler); err != nil {
59+
return "", err
60+
}
61+
return program.String(), nil
62+
}

go/codegen/alps/codegen_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package alps
15+
16+
import (
17+
"fmt"
18+
"os"
19+
"testing"
20+
21+
"github.com/stretchr/testify/assert"
22+
23+
"sqlflow.org/sqlflow/go/database"
24+
"sqlflow.org/sqlflow/go/ir"
25+
pb "sqlflow.org/sqlflow/go/proto"
26+
)
27+
28+
func mockSession() *pb.Session {
29+
db := database.GetTestingDBSingleton()
30+
return &pb.Session{DbConnStr: fmt.Sprintf("%s://%s", db.DriverName, db.DataSourceName)}
31+
}
32+
33+
func TestALPSCodegen(t *testing.T) {
34+
if os.Getenv("SQLFLOW_TEST_DB") != "maxcompute" {
35+
t.Skipf("skip TestALPSCodegen and it must use when SQLFLOW_TEST_DB=maxcompute")
36+
}
37+
a := assert.New(t)
38+
tir := ir.MockTrainStmt(false)
39+
_, err := Train(tir, mockSession())
40+
a.NoError(err)
41+
}

go/codegen/alps/template_train.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package alps
15+
16+
import "sqlflow.org/sqlflow/go/ir"
17+
18+
type trainFiller struct {
19+
DataSource string
20+
TrainSelect string
21+
ValidationSelect string
22+
Estimator string
23+
FieldDescs map[string][]*ir.FieldDesc
24+
FeatureColumnCode string
25+
Y *ir.FieldDesc
26+
ModelParams map[string]interface{}
27+
TrainParams map[string]interface{}
28+
ValidationParams map[string]interface{}
29+
Save string
30+
TmpTrainTable string
31+
TmpValidateTable string
32+
}
33+
34+
// templateTrain is the text/template source of the Python program that trains
// a model with ALPS. It is rendered with a trainFiller and expects the
// template functions DTypeToString, intArrayToJSONString and
// attrToPythonValue to be registered.
//
// The generated program builds feature_metas and label_meta from the field
// descriptions, constructs the estimator inside SQLFlowEstimatorBuilder, and
// calls runtime.alps.train.train using a local "scratch" directory that is
// removed afterwards.
//
// The separator of a sparse column is read from the "delimiter" entry of
// feature_metas, which is the only key under which the delimiter is stored.
var templateTrain = `import copy
import os
import shutil

import tensorflow as tf
from alps.framework.column.column import (DenseColumn, GroupedSparseColumn,
                                          SparseColumn)
from alps.framework.engine import LocalEngine
from alps.framework.experiment import EstimatorBuilder
from alps.io.base import OdpsConf
from runtime import db
from runtime.alps.train import train
from runtime.tensorflow.get_tf_version import tf_is_version2

feature_column_names = [{{range $target, $desclist := .FieldDescs}}{{range $desclist}}
"{{.Name}}",
{{end}}{{end}}]

feature_metas = dict()
{{ range $target, $desclist := .FieldDescs }}
{{ range $value := $desclist }}
feature_metas["{{$value.Name}}"] = {
    "feature_name": "{{$value.Name}}",
    "dtype": "{{$value.DType | DTypeToString}}",
    "delimiter": "{{$value.Delimiter}}",
    "format": "{{$value.Format}}",
    "shape": {{$value.Shape | intArrayToJSONString}},
    "is_sparse": "{{$value.IsSparse}}" == "true"
}
{{end}}
{{end}}

label_meta = {
    "feature_name": "{{.Y.Name}}",
    "dtype": "{{.Y.DType | DTypeToString}}",
    "delimiter": "{{.Y.Delimiter}}",
    "shape": {{.Y.Shape | intArrayToJSONString}},
    "is_sparse": "{{.Y.IsSparse}}" == "true"
}

model_params=dict()
{{range $k, $v := .ModelParams}}
model_params["{{$k}}"]={{$v | attrToPythonValue}}
{{end}}

# Construct optimizer objects to pass to model initializer.
# The original model_params is serializable (do not have tf.xxx objects).
model_params_constructed = copy.deepcopy(model_params)
for optimizer_arg in ["optimizer", "dnn_optimizer", "linear_optimizer"]:
    if optimizer_arg in model_params_constructed:
        model_params_constructed[optimizer_arg] = eval(model_params_constructed[optimizer_arg])

if "loss" in model_params_constructed:
    model_params_constructed["loss"] = eval(model_params_constructed["loss"])


class SQLFlowEstimatorBuilder(EstimatorBuilder):
    def _build(self, experiment, run_config):
        feature_columns_map = {{.FeatureColumnCode}}
        if feature_columns_map.get("feature_columns"):
            feature_columns = feature_columns_map["feature_columns"]
        else:
            raise ValueError("Not supported feature column map")
        model_params_constructed["feature_columns"] = feature_columns
        return tf.estimator.{{.Estimator}}(config=run_config,
                                           **model_params_constructed)


if __name__ == "__main__":
    if tf_is_version2():
        raise ValueError("ALPS must run with Tensorflow == 1.15.x")

    driver, dsn = "{{.DataSource}}".split("://")
    user, passwd, endpoint, odps_project = db.parseMaxComputeDSN(dsn)
    odps_conf = OdpsConf(
        accessid=user,
        accesskey=passwd,
        # endpoint should looks like: "https://service.cn.maxcompute.aliyun.com/api"
        endpoint=endpoint,
        project=odps_project)

    features = []
    for col_name in feature_column_names:
        # NOTE: add sparse columns like: SparseColumn(name="deep_id", shape=[15033], dtype="int")
        if feature_metas[col_name]["is_sparse"]:
            features.append(SparseColumn(name=feature_metas[col_name]["feature_name"],
                                         shape=feature_metas[col_name]["shape"],
                                         dtype=feature_metas[col_name]["dtype"],
                                         separator=feature_metas[col_name]["delimiter"]))
        else:
            features.append(DenseColumn(name=feature_metas[col_name]["feature_name"],
                                        shape=feature_metas[col_name]["shape"],
                                        dtype=feature_metas[col_name]["dtype"]))
    labels = DenseColumn(name=label_meta["feature_name"],
                         shape=label_meta["shape"],
                         dtype=label_meta["dtype"])

    try:
        os.mkdir("scratch")
    except FileExistsError:
        pass

    train_max_steps = {{index .TrainParams "max_steps" | attrToPythonValue}}
    train_max_steps = None if train_max_steps == 0 else train_max_steps

    # TODO(typhoonzero): support pass feature_map_table from WITH attributes.
    # TODO(typhoonzero): pass actual use_id.
    # TODO(typhoonzero): pass engine config to submit jobs to the cluster.
    train(SQLFlowEstimatorBuilder(),
          odps_conf=odps_conf,
          project=odps_project,
          train_table="{{.TmpTrainTable}}",
          eval_table="{{.TmpValidateTable}}",
          features=features,
          labels=labels,
          feature_map_table="",
          feature_map_partition="",
          epochs=1,
          batch_size=2,
          shuffle=False,
          shuffle_bufsize=128,
          cache_file="",
          max_steps=train_max_steps,
          eval_steps={{index .ValidationParams "steps" | attrToPythonValue}},
          eval_batch_size=1,
          eval_start_delay={{index .ValidationParams "start_delay_secs" | attrToPythonValue}},
          eval_throttle={{index .ValidationParams "throttle_secs" | attrToPythonValue}},
          drop_remainder=True,
          export_path="./scratch/model",
          scratch_dir="./scratch",
          user_id="",
          engine_config={"name": "LocalEngine"},
          exit_on_submit=False)
    shutil.rmtree("scratch")
`

0 commit comments

Comments
 (0)