Skip to content

Commit aab998d

Browse files
committed
feat: train model without dumping input data
1 parent c92dc38 commit aab998d

File tree

9 files changed

+276
-10
lines changed

9 files changed

+276
-10
lines changed

traindb-common/src/main/java/traindb/common/TrainDBConfiguration.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ public TrainDBConfiguration(Properties p) {
3636
this.props = p;
3737
}
3838

39-
public static String getModelRunnerPath() {
40-
return getTrainDBPrefixPath() + "/models/TrainDBModelRunner.py";
39+
public String getModelRunner() {
40+
return (String) props.getOrDefault("traindb.server.modelrunner", "file");
4141
}
4242

4343
public static String getTrainDBPrefixPath() {

traindb-core/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ limitations under the License.
127127
</exclusion>
128128
</exclusions>
129129
</dependency>
130+
<dependency>
131+
<groupId>net.sf.py4j</groupId>
132+
<artifactId>py4j</artifactId>
133+
</dependency>
130134

131135
<!-- Test Support -->
132136
<dependency>

traindb-core/src/main/java/traindb/engine/TrainDBFileModelRunner.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ public TrainDBFileModelRunner(
3838
super(conn, catalogContext, modeltypeName, modelName);
3939
}
4040

41+
public static String getModelRunnerPath() {
42+
return TrainDBConfiguration.getTrainDBPrefixPath() + "/models/TrainDBModelRunner.py";
43+
}
44+
4145
@Override
4246
public String trainModel(String schemaName, String tableName, List<String> columnNames,
4347
Map<String, Object> trainOptions) throws Exception {
@@ -64,8 +68,7 @@ public String trainModel(String schemaName, String tableName, List<String> colum
6468
MModeltype mModeltype = catalogContext.getModeltype(modeltypeName);
6569

6670
// train ML model
67-
ProcessBuilder pb = new ProcessBuilder("python",
68-
TrainDBConfiguration.getModelRunnerPath(), "train",
71+
ProcessBuilder pb = new ProcessBuilder("python", getModelRunnerPath(), "train",
6972
mModeltype.getClassName(), TrainDBConfiguration.absoluteUri(mModeltype.getUri()),
7073
dataFilename, metadataFilename, outputPath);
7174
pb.inheritIO();
@@ -86,8 +89,7 @@ public void generateSynopsis(String outputPath, int rows) throws Exception {
8689
MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype();
8790

8891
// generate synopsis from ML model
89-
ProcessBuilder pb = new ProcessBuilder("python",
90-
TrainDBConfiguration.getModelRunnerPath(), "synopsis",
92+
ProcessBuilder pb = new ProcessBuilder("python", getModelRunnerPath(), "synopsis",
9193
mModeltype.getClassName(), TrainDBConfiguration.absoluteUri(mModeltype.getUri()),
9294
modelPath, String.valueOf(rows), outputPath);
9395
pb.inheritIO();
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.engine;
16+
17+
/**
 * Contract of the model runner that is called from Java through a Py4J gateway.
 *
 * <p>In this commit the interface is implemented on the Python side (the Python class lists
 * {@code traindb.engine.TrainDBModelRunner} in its {@code Java.implements} attribute) and is
 * obtained on the Java side via {@code GatewayServer.getPythonServerEntryPoint}.
 */
public interface TrainDBModelRunner {
18+
19+
/**
 * Initializes the runner.
 *
 * <p>NOTE(review): the Python implementation in this commit declares
 * {@code init(java_port, python_port)}, which does not match this zero-argument signature;
 * confirm this method is never invoked through the gateway.
 */
void init();
20+
21+
/**
 * Asks the runner to open a JDBC connection to the source database.
 *
 * @param driverClass fully qualified JDBC driver class name
 * @param url JDBC connection URL
 * @param user database user name
 * @param password database password
 * @param jdbcJarPath location of the JAR that contains the JDBC driver
 */
void connect(String driverClass, String url, String user, String password, String jdbcJarPath);
22+
23+
/**
 * Trains a model on the rows selected by the given query and stores it under the model path.
 *
 * @param sqlTrainingData SQL query that selects the training data
 * @param modeltypeClass name of the model type class to instantiate
 * @param modelTypePath location of the file that defines the model type class
 * @param jsonTrainingMetadata training metadata serialized as a JSON string
 * @param modelPath location where the trained model is saved
 * @return training information; the Python implementation in this commit returns a JSON string
 */
String trainModel(String sqlTrainingData, String modeltypeClass, String modelTypePath,
24+
String jsonTrainingMetadata, String modelPath);
25+
26+
/**
 * Loads a previously trained model and writes a synopsis generated from it to a file.
 *
 * @param modeltypeClass name of the model type class to instantiate
 * @param modeltypePath location of the file that defines the model type class
 * @param modelPath location of the saved model
 * @param rowCount number of synopsis rows requested
 * @param outputFile file the synopsis is written to (CSV in the Python implementation)
 */
void generateSynopsis(String modeltypeClass, String modeltypePath, String modelPath,
27+
int rowCount, String outputFile);
28+
29+
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.engine;
16+
17+
import static java.lang.Thread.sleep;
18+
19+
import java.io.IOException;
20+
import java.net.ServerSocket;
21+
import java.nio.file.Files;
22+
import java.nio.file.Path;
23+
import java.util.List;
24+
import java.util.Map;
25+
import org.apache.commons.dbcp2.BasicDataSource;
26+
import org.json.simple.JSONObject;
27+
import py4j.GatewayServer;
28+
import traindb.catalog.CatalogContext;
29+
import traindb.catalog.pm.MModeltype;
30+
import traindb.common.TrainDBConfiguration;
31+
import traindb.common.TrainDBException;
32+
import traindb.jdbc.TrainDBConnectionImpl;
33+
34+
public class TrainDBPy4JModelRunner extends AbstractTrainDBModelRunner {
35+
36+
public TrainDBPy4JModelRunner(
37+
TrainDBConnectionImpl conn, CatalogContext catalogContext, String modeltypeName,
38+
String modelName) {
39+
super(conn, catalogContext, modeltypeName, modelName);
40+
}
41+
42+
public static String getModelRunnerPath() {
43+
return TrainDBConfiguration.getTrainDBPrefixPath() + "/models/TrainDBPy4JModelRunner.py";
44+
}
45+
46+
@Override
47+
public String trainModel(String schemaName, String tableName, List<String> columnNames,
48+
Map<String, Object> trainOptions) throws Exception {
49+
JSONObject tableMetadata = buildTableMetadata(schemaName, tableName, columnNames, trainOptions);
50+
// write metadata for model training scripts in python
51+
Path modelPath = getModelPath();
52+
Files.createDirectories(modelPath);
53+
54+
int javaPort = getAvailablePort();
55+
int pythonPort = getAvailablePort();
56+
57+
// train ML model
58+
ProcessBuilder pb = new ProcessBuilder("python", getModelRunnerPath(),
59+
String.valueOf(javaPort), String.valueOf(pythonPort));
60+
pb.inheritIO();
61+
Process process = pb.start();
62+
63+
sleep(1000); // FIXME waiting the python process
64+
GatewayServer server = new GatewayServer(null, javaPort, pythonPort, 0, 0, null);
65+
server.start();
66+
67+
TrainDBModelRunner modelRunner = (TrainDBModelRunner) server.getPythonServerEntryPoint(
68+
new Class[] { TrainDBModelRunner.class });
69+
70+
BasicDataSource ds = conn.getDataSource();
71+
Class jdbcClass = Class.forName(ds.getDriverClassName());
72+
MModeltype mModeltype = catalogContext.getModeltype(modeltypeName);
73+
String trainInfo;
74+
try {
75+
modelRunner.connect(
76+
ds.getDriverClassName(), ds.getUrl(), ds.getUsername(), ds.getPassword(),
77+
jdbcClass.getProtectionDomain().getCodeSource().getLocation().getPath());
78+
trainInfo = modelRunner.trainModel(
79+
buildSelectTrainingDataQuery(schemaName, tableName, columnNames),
80+
mModeltype.getClassName(), TrainDBConfiguration.absoluteUri(mModeltype.getUri()),
81+
tableMetadata.toJSONString(), modelPath.toString());
82+
} catch (Exception e) {
83+
server.shutdown();
84+
process.destroy();
85+
e.printStackTrace();
86+
throw new TrainDBException("failed to train model");
87+
}
88+
server.shutdown();
89+
process.destroy();
90+
91+
return trainInfo;
92+
}
93+
94+
public void generateSynopsis(String outputPath, int rows) throws Exception {
95+
String modelPath = getModelPath().toString();
96+
MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype();
97+
98+
// generate synopsis from ML model
99+
int javaPort = getAvailablePort();
100+
int pythonPort = getAvailablePort();
101+
102+
// train ML model
103+
ProcessBuilder pb = new ProcessBuilder("python", getModelRunnerPath(),
104+
String.valueOf(javaPort), String.valueOf(pythonPort));
105+
pb.inheritIO();
106+
Process process = pb.start();
107+
108+
sleep(1000); // FIXME waiting the python process
109+
GatewayServer server = new GatewayServer(null, javaPort, pythonPort, 0, 0, null);
110+
server.start();
111+
112+
TrainDBModelRunner modelRunner = (TrainDBModelRunner) server.getPythonServerEntryPoint(
113+
new Class[] { TrainDBModelRunner.class });
114+
115+
try {
116+
modelRunner.generateSynopsis(
117+
mModeltype.getClassName(), TrainDBConfiguration.absoluteUri(mModeltype.getUri()),
118+
modelPath, rows, outputPath);
119+
} catch (Exception e) {
120+
server.shutdown();
121+
process.destroy();
122+
e.printStackTrace();
123+
throw new TrainDBException("failed to create synopsis");
124+
}
125+
server.shutdown();
126+
process.destroy();
127+
}
128+
129+
private int getAvailablePort() throws Exception {
130+
ServerSocket s;
131+
try {
132+
s = new ServerSocket(0);
133+
} catch (IOException e) {
134+
throw new TrainDBException("failed to get an available port");
135+
}
136+
return s.getLocalPort();
137+
}
138+
}

traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ public void trainModel(
8383
schemaName = conn.getSchema();
8484
}
8585

86-
AbstractTrainDBModelRunner runner =
87-
new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName);
86+
AbstractTrainDBModelRunner runner = createModelRunner(modeltypeName, modelName);
8887
String trainInfo = runner.trainModel(schemaName, tableName, columnNames, trainOptions);
8988

9089
JSONParser jsonParser = new JSONParser();
@@ -181,6 +180,15 @@ private void loadSynopsisIntoTable(String synopsisName, MModel mModel,
181180
}
182181
}
183182

183+
private AbstractTrainDBModelRunner createModelRunner(String modeltypeName, String modelName) {
184+
String modelrunner = conn.cfg.getModelRunner();
185+
if (modelrunner.equals("py4j")) {
186+
return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName);
187+
}
188+
189+
return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName);
190+
}
191+
184192
@Override
185193
public void createSynopsis(String synopsisName, String modelName, int limitNumber)
186194
throws Exception {
@@ -194,8 +202,7 @@ public void createSynopsis(String synopsisName, String modelName, int limitNumbe
194202
MModel mModel = catalogContext.getModel(modelName);
195203
MModeltype mModeltype = mModel.getModeltype();
196204

197-
AbstractTrainDBModelRunner runner =
198-
new TrainDBFileModelRunner(conn, catalogContext, mModeltype.getName(), modelName);
205+
AbstractTrainDBModelRunner runner = createModelRunner(mModeltype.getName(), modelName);
199206
String outputPath = runner.getModelPath().toString() + '/' + synopsisName + ".csv";
200207
runner.generateSynopsis(outputPath, limitNumber);
201208

traindb-core/src/main/java/traindb/jdbc/TrainDBConnectionImpl.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ private BasicDataSource dataSource(String url, Properties info) {
198198
return dataSource;
199199
}
200200

201+
public BasicDataSource getDataSource() {
202+
return dataSource;
203+
}
204+
201205
private Connection extraConnection() {
202206
try {
203207
return dataSource.getConnection();
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License.
13+
"""
14+
15+
import importlib
16+
import json
17+
import os
18+
import jaydebeapi
19+
from py4j.java_gateway import JavaGateway, GatewayParameters, CallbackServerParameters
20+
import pandas as pd
21+
22+
class TrainDBModelRunner(object):
    """Python side of the Py4J model runner.

    The nested ``Java`` class declares that instances implement the Java
    interface ``traindb.engine.TrainDBModelRunner`` so they can be used as the
    Python server entry point of a Py4J gateway.
    """

    class Java:
        implements = [ "traindb.engine.TrainDBModelRunner" ]

    def init(self, java_port, python_port):
        """Connect to the Java gateway and register this object as entry point.

        :param java_port: port of the Java-side gateway server
        :param python_port: port the Python callback server listens on
        """
        self.gateway = JavaGateway(
            gateway_parameters=GatewayParameters(port=java_port),
            callback_server_parameters=CallbackServerParameters(port=python_port),
            python_server_entry_point=self)

    def connect(self, driver_class_name, url, user, password, jdbc_jar_path):
        """Open a JDBC connection through jaydebeapi and keep it in ``self.conn``."""
        self.conn = jaydebeapi.connect(
            driver_class_name, url, [ user, password ], jdbc_jar_path)

    def trainModel(self, sql_training_data, modeltype_class, modeltype_path, training_metadata, model_path, args=None, kwargs=None):
        """Fetch the training data with SQL, train the model and save it.

        :param sql_training_data: query executed on ``self.conn`` to get the rows
        :param modeltype_class: name of the model class inside the module
        :param modeltype_path: path of the Python file defining the model class
        :param training_metadata: JSON string; its ``options`` entry is passed
            to the model constructor as keyword arguments
        :param model_path: location handed to ``model.save``
        :param args: optional positional arguments for the model constructor
        :param kwargs: accepted for compatibility; not used
        :return: JSON string with ``base_table_rows`` and ``trained_rows``
        """
        # None instead of mutable default arguments; an absent value means "no extra args".
        ctor_args = () if args is None else args

        curs = self.conn.cursor()
        try:
            curs.execute(sql_training_data)
            header = [desc[0] for desc in curs.description]
            data = pd.DataFrame(curs.fetchall(), columns=header)
        finally:
            # Release the cursor even when the query or the fetch fails.
            curs.close()
        metadata = json.loads(training_metadata)

        mod = self._load_module(modeltype_class, modeltype_path)
        model = getattr(mod, modeltype_class)(*ctor_args, **metadata['options'])
        model.train(data, metadata)
        model.save(model_path)

        train_info = {}
        train_info['base_table_rows'] = len(data.index)
        train_info['trained_rows'] = len(data.index)
        return json.dumps(train_info)

    def generateSynopsis(self, modeltype_class, modeltype_path, model_path, row_count, output_file):
        """Load a saved model and write ``row_count`` synopsis rows as CSV.

        :param modeltype_class: name of the model class inside the module
        :param modeltype_path: path of the Python file defining the model class
        :param model_path: location handed to ``model.load``
        :param row_count: number of rows requested from ``model.synopsis``
        :param output_file: CSV file to write (without the index column)
        """
        mod = self._load_module(modeltype_class, modeltype_path)
        model = getattr(mod, modeltype_class)()
        model.load(model_path)
        syn_data = model.synopsis(row_count)
        syn_data.to_csv(output_file, index=False)

    def _load_module(self, modeltype_class, modeltype_path):
        """Import the Python file at ``modeltype_path`` as a module and return it."""
        # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
        import importlib.util

        spec = importlib.util.spec_from_file_location(modeltype_class, modeltype_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        return mod
67+
68+
69+
import argparse

# Only parse the command line and start the gateway when executed as a script,
# so that importing this module has no side effects.
if __name__ == '__main__':
    root_parser = argparse.ArgumentParser(description='TrainDB Model Runner')
    root_parser.add_argument('java_port', type=int)
    root_parser.add_argument('python_port', type=int)
    args = root_parser.parse_args()

    # Register the runner as the Python entry point of the Py4J gateway.
    runner = TrainDBModelRunner()
    runner.init(args.java_port, args.python_port)

traindb-project/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ limitations under the License.
227227
<artifactId>jdo-api</artifactId>
228228
<version>3.1</version>
229229
</dependency>
230+
<dependency>
231+
<groupId>net.sf.py4j</groupId>
232+
<artifactId>py4j</artifactId>
233+
<version>0.10.9.7</version>
234+
</dependency>
230235

231236
<!-- Test Support -->
232237
<dependency>

0 commit comments

Comments
 (0)