Skip to content

Commit b167375

Browse files
authored
Generate couler code of workflow steps (#2806)
* wip * fix yaml generate * fix tests * fix package deps * fix pip package deps * update
1 parent 28cace1 commit b167375

File tree

9 files changed

+167
-26
lines changed

9 files changed

+167
-26
lines changed

docker/ci/install-pips.bash

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ set -e
2323
# the saved PMML file is right.
2424
pip install --quiet \
2525
numpy==1.16.2 \
26+
tensorflow-metadata==0.22.2 \
2627
tensorflow==2.0.1 \
2728
impyla==0.16.0 \
2829
pyodps==0.8.3 \

go/codegen/experimental/codegen.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,32 @@ import (
2727
// 2. generate runtime code of each statement
2828
// 3. generate couler program to form a workflow
2929
func GenerateCodeCouler(sqlProgram string, session *pb.Session) (string, error) {
30+
var defaultDockerImage = "sqlflow/sqlflow:step"
3031
stmts, err := parseToIR(sqlProgram, session)
3132
if err != nil {
3233
return "", err
3334
}
34-
for _, stmt := range stmts {
35-
stepCode, err := generateStepCode(stmt, session)
35+
stepList := []*stepContext{}
36+
for idx, stmt := range stmts {
37+
stepCode, err := generateStepCode(stmt, idx, session)
3638
if err != nil {
3739
return "", err
3840
}
39-
fmt.Println(stepCode)
41+
image := defaultDockerImage
42+
if trainStmt, ok := stmt.(*ir.TrainStmt); ok {
43+
if trainStmt.ModelImage != "" {
44+
image = trainStmt.ModelImage
45+
}
46+
}
47+
// TODO(typhoonzero): find out the image that should be used by the predict statements.
48+
step := &stepContext{
49+
Code: stepCode,
50+
Image: image,
51+
StepIndex: idx,
52+
}
53+
stepList = append(stepList, step)
4054
}
41-
return "", nil
55+
return CodeGenCouler(stepList, session)
4256
}
4357

4458
func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error) {
@@ -95,12 +109,12 @@ func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error)
95109
return result, nil
96110
}
97111

98-
func generateStepCode(stmt ir.SQLFlowStmt, session *pb.Session) (string, error) {
112+
func generateStepCode(stmt ir.SQLFlowStmt, stepIndex int, session *pb.Session) (string, error) {
99113
switch stmt.(type) {
100114
case *ir.TrainStmt:
101115
trainStmt := stmt.(*ir.TrainStmt)
102116
if strings.HasPrefix(strings.ToUpper(trainStmt.Estimator), "XGBOOST.") {
103-
return XGBoostGenerateTrain(trainStmt, session)
117+
return XGBoostGenerateTrain(trainStmt, stepIndex, session)
104118
}
105119
return "", fmt.Errorf("not implemented estimator type %s", trainStmt.Estimator)
106120
default:
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package experimental
15+
16+
import (
17+
"bytes"
18+
"fmt"
19+
"os"
20+
"strconv"
21+
"text/template"
22+
23+
pb "sqlflow.org/sqlflow/go/proto"
24+
"sqlflow.org/sqlflow/go/workflow/couler"
25+
)
26+
27+
var workflowTTL = 24 * 3600
28+
29+
type stepContext struct {
30+
Code string
31+
StepIndex int
32+
Image string
33+
}
34+
35+
type coulerFiller struct {
36+
StepList []*stepContext
37+
DataSource string
38+
StepEnvs map[string]string
39+
WorkflowTTL int
40+
SecretName string
41+
SecretData string
42+
Resources string
43+
}
44+
45+
// CodeGenCouler generates the couler code that defines a workflow.
46+
func CodeGenCouler(stepList []*stepContext, session *pb.Session) (string, error) {
47+
var workflowResourcesEnv = "SQLFLOW_WORKFLOW_RESOURCES"
48+
envs, err := couler.GetStepEnvs(session)
49+
if err != nil {
50+
return "", err
51+
}
52+
secretName, secretData, err := couler.GetSecret()
53+
if err != nil {
54+
return "", err
55+
}
56+
if err := couler.VerifyResources(os.Getenv(workflowResourcesEnv)); err != nil {
57+
return "", err
58+
}
59+
if os.Getenv("SQLFLOW_WORKFLOW_TTL") != "" {
60+
workflowTTL, err = strconv.Atoi(os.Getenv("SQLFLOW_WORKFLOW_TTL"))
61+
if err != nil {
62+
return "", fmt.Errorf("SQLFLOW_WORKFLOW_TTL: %s should be int", os.Getenv("SQLFLOW_WORKFLOW_TTL"))
63+
}
64+
}
65+
66+
filler := &coulerFiller{
67+
StepList: stepList,
68+
DataSource: session.DbConnStr,
69+
StepEnvs: envs,
70+
WorkflowTTL: workflowTTL,
71+
SecretName: secretName,
72+
SecretData: secretData,
73+
Resources: os.Getenv(workflowResourcesEnv),
74+
}
75+
var program bytes.Buffer
76+
if err := coulerTemplate.Execute(&program, filler); err != nil {
77+
return "", err
78+
}
79+
return program.String(), nil
80+
}
81+
82+
var coulerCodeTmpl = `
83+
import couler.argo as couler
84+
import json
85+
import re
86+
87+
datasource = "{{ .DataSource }}"
88+
89+
step_envs = dict()
90+
{{range $k, $v := .StepEnvs}}step_envs["{{$k}}"] = '''{{$v}}'''
91+
{{end}}
92+
93+
sqlflow_secret = None
94+
if "{{.SecretName}}" != "":
95+
# note(yancey1989): set dry_run to true, just reference the secret meta to generate workflow YAML,
96+
# we should create the secret before launching sqlflowserver
97+
secret_data=json.loads('''{{.SecretData}}''')
98+
sqlflow_secret = couler.secret(secret_data, name="{{ .SecretName }}", dry_run=True)
99+
100+
resources = None
101+
if '''{{.Resources}}''' != "":
102+
resources=json.loads('''{{.Resources}}''')
103+
104+
couler.clean_workflow_after_seconds_finished({{.WorkflowTTL}})
105+
106+
{{ range $ss := .StepList }}
107+
{{.Code}}
108+
couler.run_script(image="{{.Image}}", source=step_entry_{{.StepIndex}}, env=step_envs, resources=resources)
109+
{{end}}
110+
`
111+
112+
var coulerTemplate = template.Must(template.New("Couler").Parse(coulerCodeTmpl))

go/codegen/experimental/codegen_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,34 @@ package experimental
1515

1616
import (
1717
"os"
18+
"strings"
1819
"testing"
1920

21+
"github.com/stretchr/testify/assert"
2022
"sqlflow.org/sqlflow/go/database"
2123
pb "sqlflow.org/sqlflow/go/proto"
2224
)
2325

2426
func TestExperimentalXGBCodegen(t *testing.T) {
27+
a := assert.New(t)
2528
if os.Getenv("SQLFLOW_TEST_DB") != "mysql" {
2629
t.Skipf("skip TestExperimentalXGBCodegen of DB type %s", os.Getenv("SQLFLOW_TEST_DB"))
2730
}
2831
// test without COLUMN clause
2932
sql := "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 LABEL class INTO sqlflow_models.xgb_classification;"
3033
s := &pb.Session{DbConnStr: database.GetTestingMySQLURL()}
31-
_, err := GenerateCodeCouler(sql, s)
34+
coulerCode, err := GenerateCodeCouler(sql, s)
3235
if err != nil {
3336
t.Errorf("error %s", err)
3437
}
38+
a.True(strings.Contains(coulerCode, `couler.run_script(image="sqlflow/sqlflow:step", source=step_entry_0, env=step_envs, resources=resources)`))
3539

3640
// test with COLUMN clause
3741
sql = "SELECT * FROM iris.train TO TRAIN xgboost.gbtree WITH objective=\"binary:logistic\",num_class=3 COLUMN petal_length LABEL class INTO sqlflow_models.xgb_classification;"
38-
_, err = GenerateCodeCouler(sql, s)
42+
coulerCode, err = GenerateCodeCouler(sql, s)
3943
if err != nil {
4044
t.Errorf("error %s", err)
4145
}
46+
expected := `feature_column_map = {"featuren_columns": [fc.NumericColumn(fd.FieldDesc(name="petal_length", dtype=fd.DataType.FLOAT32, delimiter="", format="", shape=[1], is_sparse=False, vocabulary=[]))]}`
47+
a.True(strings.Contains(coulerCode, expected))
4248
}

go/codegen/experimental/xgboost.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
)
2929

3030
type xgbTrainFiller struct {
31+
StepIndex int
3132
DataSource string
3233
Select string
3334
ValidationSelect string
@@ -41,8 +42,8 @@ type xgbTrainFiller struct {
4142
Submitter string
4243
}
4344

44-
// XGBoostGenerateTrain returns the step code
45-
func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
45+
// XGBoostGenerateTrain returns the step code.
46+
func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, stepIndex int, session *pb.Session) (string, error) {
4647
var err error
4748
if err = resolveModelParams(trainStmt); err != nil {
4849
return "", err
@@ -93,6 +94,7 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, session *pb.Session) (string,
9394
}
9495

9596
filler := xgbTrainFiller{
97+
StepIndex: stepIndex,
9698
DataSource: session.DbConnStr,
9799
Select: trainStmt.Select,
98100
ValidationSelect: trainStmt.ValidationSelect,
@@ -115,7 +117,7 @@ func XGBoostGenerateTrain(trainStmt *ir.TrainStmt, session *pb.Session) (string,
115117
}
116118

117119
var xgbTrainTemplate = `
118-
def step_entry():
120+
def step_entry_{{.StepIndex}}():
119121
import json
120122
import tempfile
121123
import os

go/workflow/couler/codegen.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ func fillEnvFromSession(envs *map[string]string, session *pb.Session) {
5151
fillMapIfValueNotEmpty(*envs, "SQLFLOW_submitter", session.Submitter)
5252
}
5353

54-
func getStepEnvs(session *pb.Session) (map[string]string, error) {
54+
// GetStepEnvs returns a map of envs used for couler workflow.
55+
func GetStepEnvs(session *pb.Session) (map[string]string, error) {
5556
envs := make(map[string]string)
5657
// fill step envs from the environment variables on sqlflowserver
5758
for _, env := range os.Environ() {
@@ -67,7 +68,9 @@ func getStepEnvs(session *pb.Session) (map[string]string, error) {
6768
fillEnvFromSession(&envs, session)
6869
return envs, nil
6970
}
70-
func verifyResources(resources string) error {
71+
72+
// VerifyResources verifies that the SQLFLOW_WORKFLOW_RESOURCES env is valid.
73+
func VerifyResources(resources string) error {
7174
if resources != "" {
7275
var r map[string]interface{}
7376
if e := json.Unmarshal([]byte(resources), &r); e != nil {
@@ -77,7 +80,8 @@ func verifyResources(resources string) error {
7780
return nil
7881
}
7982

80-
func getSecret() (string, string, error) {
83+
// GetSecret returns the workflow secret name and value.
84+
func GetSecret() (string, string, error) {
8185
secretMap := make(map[string]map[string]string)
8286
secretCfg := os.Getenv("SQLFLOW_WORKFLOW_SECRET")
8387
if secretCfg == "" {
@@ -99,7 +103,7 @@ func getSecret() (string, string, error) {
99103

100104
// GenFiller generates Filler to fill the template
101105
func GenFiller(programIR []ir.SQLFlowStmt, session *pb.Session) (*Filler, error) {
102-
stepEnvs, err := getStepEnvs(session)
106+
stepEnvs, err := GetStepEnvs(session)
103107
if err != nil {
104108
return nil, err
105109
}
@@ -109,11 +113,11 @@ func GenFiller(programIR []ir.SQLFlowStmt, session *pb.Session) (*Filler, error)
109113
return nil, fmt.Errorf("SQLFLOW_WORKFLOW_TTL: %s should be int", os.Getenv("SQLFLOW_WORKFLOW_TTL"))
110114
}
111115
}
112-
secretName, secretData, e := getSecret()
116+
secretName, secretData, e := GetSecret()
113117
if e != nil {
114118
return nil, e
115119
}
116-
if e := verifyResources(os.Getenv(envResource)); e != nil {
120+
if e := VerifyResources(os.Getenv(envResource)); e != nil {
117121
return nil, e
118122
}
119123

go/workflow/couler/codegen_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ func TestStringInStringSQL(t *testing.T) {
142142
yaml, e := cg.GenYAML(code)
143143
a.NoError(e)
144144
println(yaml)
145-
expect := `validation.select=\"select * from iris.train where name like \\\"Versicolor\\\";\"`
145+
expect := `validation.select=\\\"select * from iris.train where\
146+
\ name like \\\\\\\"Versicolor\\\\\\\";\\\"`
146147
a.True(strings.Contains(yaml, expect))
147148
}
148149

python/couler/couler/argo.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
import couler.pyfunc as pyfunc
2121
import pyaml
2222

23-
_wf: dict = {}
24-
_secrets: dict = {}
25-
_steps: OrderedDict = OrderedDict()
26-
_templates: dict = {}
23+
_wf = dict()
24+
_secrets = dict()
25+
_steps = OrderedDict()
26+
_templates = dict()
2727
_update_steps_lock = True
2828
_run_concurrent_lock = False
2929
_concurrent_func_line = -1
@@ -36,7 +36,7 @@
3636
# '_condition_id' records the line number where the 'couler.when()' is invoked.
3737
_condition_id = None
3838
# '_while_steps' records the step of recursive logic
39-
_while_steps: OrderedDict = OrderedDict()
39+
_while_steps = OrderedDict()
4040
# '_while_lock' indicates the recursive call starts
4141
_while_lock = False
4242
# TTL_cleaned for the workflow
@@ -506,7 +506,7 @@ def concurrent(function_list):
506506
_run_concurrent_lock = False
507507

508508

509-
def yaml():
509+
def __dump_yaml():
510510
wf = copy.deepcopy(_wf)
511511
wf["apiVersion"] = "argoproj.io/v1alpha1"
512512
wf["kind"] = "Workflow"
@@ -529,10 +529,10 @@ def yaml():
529529
def _dump_yaml():
530530
yaml_str = ""
531531
if len(_secrets) > 0:
532-
yaml_str = pyaml.dump(_secrets, string_val_style="plain")
532+
yaml_str = pyaml.dump(_secrets)
533533
yaml_str = "%s\n---\n" % yaml_str
534534
if len(_steps) > 0:
535-
yaml_str = yaml_str + pyaml.dump(yaml(), string_val_style="plain")
535+
yaml_str = yaml_str + pyaml.dump(__dump_yaml())
536536
print(yaml_str)
537537

538538

scripts/test/prepare.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ source build/env/bin/activate
2525

2626
python -m pip install --quiet \
2727
numpy==1.16.2 \
28+
tensorflow-metadata==0.22.2 \
2829
tensorflow==2.0.1 \
2930
impyla==0.16.0 \
3031
pyodps==0.8.3 \

0 commit comments

Comments
 (0)