Skip to content

Commit c460d30

Browse files
authored
Merge pull request #40 from traindb-project/dev/issue-39
Feat: Export/Import Model
2 parents 66bf276 + dade9bd commit c460d30

24 files changed

+707
-11
lines changed

traindb-catalog/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ limitations under the License.
5454
<groupId>org.xerial</groupId>
5555
<artifactId>sqlite-jdbc</artifactId>
5656
</dependency>
57+
<dependency>
58+
<groupId>com.googlecode.json-simple</groupId>
59+
<artifactId>json-simple</artifactId>
60+
</dependency>
5761
</dependencies>
5862

5963
<build>

traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Map;
2020
import org.apache.calcite.rel.type.RelDataType;
2121
import org.checkerframework.checker.nullness.qual.Nullable;
22+
import org.json.simple.JSONObject;
2223
import traindb.catalog.pm.MModel;
2324
import traindb.catalog.pm.MModeltype;
2425
import traindb.catalog.pm.MQueryLog;
@@ -68,6 +69,9 @@ Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns
6869

6970
void updateTrainingStatus(String modelName, String status) throws CatalogException;
7071

72+
/**
 * Registers a model in the catalog from previously exported metadata.
 *
 * @param modeltypeName name of the modeltype the imported model belongs to
 * @param modelName name under which the model is registered
 * @param exportMetadata JSON metadata of the exported model
 * @throws CatalogException if the model cannot be imported
 */
void importModel(String modeltypeName, String modelName, JSONObject exportMetadata)
73+
throws CatalogException;
74+
7175
/* Synopsis */
7276
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
7377
throws CatalogException;

traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.apache.calcite.rel.type.RelDataType;
2828
import org.apache.calcite.rel.type.RelDataTypeField;
2929
import org.checkerframework.checker.nullness.qual.Nullable;
30+
import org.json.simple.JSONArray;
31+
import org.json.simple.JSONObject;
3032
import traindb.catalog.pm.MColumn;
3133
import traindb.catalog.pm.MModel;
3234
import traindb.catalog.pm.MModeltype;
@@ -117,14 +119,15 @@ public MModel trainModel(
117119
String modeltypeName, String modelName, String schemaName, String tableName,
118120
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
119121
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
122+
MTable mTable;
120123
try {
121124
MSchema mSchema = getSchema(schemaName);
122125
if (mSchema == null) {
123126
mSchema = new MSchema(schemaName);
124127
pm.makePersistent(mSchema);
125128
}
126129

127-
MTable mTable = getTable(schemaName, tableName);
130+
mTable = getTable(schemaName, tableName);
128131
if (mTable == null) {
129132
mTable = new MTable(tableName, "TABLE", mSchema);
130133
pm.makePersistent(mTable);
@@ -146,7 +149,7 @@ public MModel trainModel(
146149
MModeltype mModeltype = getModeltype(modeltypeName);
147150
MModel mModel = new MModel(
148151
mModeltype, modelName, schemaName, tableName, columnNames,
149-
baseTableRows, trainedRows, options == null ? "" : options);
152+
baseTableRows, trainedRows, options == null ? "" : options, mTable);
150153
pm.makePersistent(mModel);
151154

152155
if (mModeltype.getLocation().equals("REMOTE")) {
@@ -274,6 +277,54 @@ public void updateTrainingStatus(String modelName, String status) throws Catalog
274277
}
275278
}
276279

280+
@Override
281+
public void importModel(String modeltypeName, String modelName, JSONObject exportMetadata)
282+
throws CatalogException {
283+
try {
284+
String schemaName = (String) exportMetadata.get("schemaName");
285+
MSchema mSchema = getSchema(schemaName);
286+
if (mSchema == null) {
287+
mSchema = pm.makePersistent(new MSchema(schemaName));
288+
}
289+
290+
String tableName = (String) exportMetadata.get("tableName");
291+
MTable mTable = getTable(schemaName, tableName);
292+
if (mTable == null) {
293+
JSONObject jsonTable = (JSONObject) exportMetadata.get("table");
294+
mTable = pm.makePersistent(
295+
new MTable(tableName, (String) jsonTable.get("tableType"), mSchema));
296+
297+
JSONArray jsonColumns = (JSONArray) jsonTable.get("columns");
298+
for (int i = 0; i < jsonColumns.size(); i++) {
299+
JSONObject jsonColumn = (JSONObject) jsonColumns.get(i);
300+
String columnName = (String) jsonColumn.get("columnName");
301+
int columnType = ((Long) jsonColumn.get("columnType")).intValue();
302+
int precision = ((Long) jsonColumn.get("precision")).intValue();
303+
int scale = ((Long) jsonColumn.get("scale")).intValue();
304+
boolean nullable = (boolean) jsonColumn.get("nullable");
305+
MColumn mColumn = new MColumn(columnName, columnType, precision, scale, nullable, mTable);
306+
pm.makePersistent(mColumn);
307+
}
308+
}
309+
310+
MModeltype mModeltype = getModeltype(modeltypeName);
311+
JSONArray jsonColumnNames = (JSONArray) exportMetadata.get("columnNames");
312+
List<String> columnNames = new ArrayList<>();
313+
for (int i = 0; i < jsonColumnNames.size(); i++) {
314+
columnNames.add((String) jsonColumnNames.get(i));
315+
}
316+
String options = (String) exportMetadata.get("modelOptions");
317+
Long baseTableRows = (Long) exportMetadata.get("tableRows");
318+
Long trainedRows = (Long) exportMetadata.get("trainedRows");
319+
MModel mModel = new MModel(
320+
mModeltype, modelName, schemaName, tableName, columnNames,
321+
baseTableRows, trainedRows, options == null ? "" : options, mTable);
322+
pm.makePersistent(mModel);
323+
} catch (RuntimeException e) {
324+
throw new CatalogException("failed to import model '" + modelName + "'", e);
325+
}
326+
}
327+
277328
@Override
278329
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
279330
@Nullable Double ratio) throws CatalogException {

traindb-catalog/src/main/java/traindb/catalog/pm/MColumn.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package traindb.catalog.pm;
1616

17+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1718
import javax.jdo.annotations.Column;
1819
import javax.jdo.annotations.IdGeneratorStrategy;
1920
import javax.jdo.annotations.PersistenceCapable;
@@ -22,6 +23,7 @@
2223
import traindb.catalog.CatalogConstants;
2324

2425
@PersistenceCapable
26+
@JsonIgnoreProperties({ "table" })
2527
public final class MColumn {
2628
@PrimaryKey
2729
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)

traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,13 @@ public final class MModel {
6363
@Persistent(mappedBy = "model", dependentElement = "true")
6464
private Collection<MTrainingStatus> training_status;
6565

66+
@Persistent(dependent = "false")
67+
private MTable table;
68+
6669
public MModel(
6770
MModeltype modeltype, String modelName, String schemaName, String tableName,
6871
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
69-
String options) {
72+
String options, MTable table) {
7073
this.modeltype = modeltype;
7174
this.model_name = modelName;
7275
this.schema_name = schemaName;
@@ -75,6 +78,7 @@ public MModel(
7578
this.table_rows = (baseTableRows == null) ? 0 : baseTableRows;
7679
this.trained_rows = (trainedRows == null) ? 0 : trainedRows;
7780
this.model_options = options.getBytes();
81+
this.table = table;
7882
}
7983

8084
public String getModelName() {
@@ -113,6 +117,10 @@ public Collection<MTrainingStatus> trainingStatus() {
113117
return training_status;
114118
}
115119

120+
/**
 * Returns the {@code table} field of this model, which the constructor sets
 * from its {@code table} argument.
 */
public MTable getTable() {
121+
return table;
122+
}
123+
116124
public boolean isEnabled() {
117125
if (training_status.isEmpty() || training_status.size() == 0) {
118126
return true;

traindb-catalog/src/main/java/traindb/catalog/pm/MTable.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package traindb.catalog.pm;
1616

17+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1718
import java.util.ArrayList;
1819
import java.util.Collection;
1920
import javax.jdo.annotations.Column;
@@ -24,6 +25,7 @@
2425
import traindb.catalog.CatalogConstants;
2526

2627
@PersistenceCapable
28+
@JsonIgnoreProperties({ "schema" })
2729
public final class MTable {
2830
@PrimaryKey
2931
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)

traindb-common/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ limitations under the License.
5454
<groupId>org.apache.calcite</groupId>
5555
<artifactId>calcite-core</artifactId>
5656
</dependency>
57+
<dependency>
58+
<groupId>commons-io</groupId>
59+
<artifactId>commons-io</artifactId>
60+
</dependency>
5761
<dependency>
5862
<groupId>sqlline</groupId>
5963
<artifactId>sqlline</artifactId>
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.util;
16+
17+
import java.io.ByteArrayInputStream;
18+
import java.io.File;
19+
import java.io.FileOutputStream;
20+
import java.io.IOException;
21+
import java.io.Writer;
22+
import java.net.URI;
23+
import java.nio.file.FileSystem;
24+
import java.nio.file.FileSystems;
25+
import java.nio.file.Files;
26+
import java.nio.file.Path;
27+
import java.nio.file.Paths;
28+
import java.nio.file.StandardCopyOption;
29+
import java.nio.file.StandardOpenOption;
30+
import java.util.HashMap;
31+
import java.util.Map;
32+
import java.util.zip.ZipEntry;
33+
import java.util.zip.ZipInputStream;
34+
import java.util.zip.ZipOutputStream;
35+
import org.apache.commons.io.IOUtils;
36+
37+
/**
 * Static helpers for creating, updating and reading ZIP archives.
 */
public final class ZipUtils {

  private static final int BUFFER_SIZE = 8192;

  private ZipUtils() {
  }

  /**
   * Packs every regular file below {@code sourceDirPath} into a new ZIP file.
   *
   * <p>Entry names are relative to the source directory and always use
   * {@code '/'} as separator, as the ZIP format requires. The archive itself is
   * skipped when it is located inside the source directory.
   *
   * @param sourceDirPath directory whose files are archived
   * @param zipFilePath path of the ZIP file to create; must not exist yet
   * @throws IOException if the ZIP file already exists or any file cannot be
   *     read or written
   */
  public static void pack(String sourceDirPath, String zipFilePath) throws IOException {
    Path zp = Files.createFile(Paths.get(zipFilePath));
    Path sp = Paths.get(sourceDirPath);
    Path zipLocation = zp.toAbsolutePath().normalize();
    // Files.walk holds open directory handles, so it must be closed as well.
    try (ZipOutputStream zs = new ZipOutputStream(Files.newOutputStream(zp));
         java.util.stream.Stream<Path> paths = Files.walk(sp)) {
      for (Path path : (Iterable<Path>) paths::iterator) {
        if (Files.isDirectory(path)
            || path.toAbsolutePath().normalize().equals(zipLocation)) {
          continue;
        }
        String entryName = sp.relativize(path).toString().replace(File.separatorChar, '/');
        zs.putNextEntry(new ZipEntry(entryName));
        Files.copy(path, zs);
        zs.closeEntry();
      }
    } catch (java.io.UncheckedIOException e) {
      // The directory walk reports I/O failures unchecked; restore the declared type.
      throw e.getCause();
    }
  }

  /**
   * Copies {@code file} into the root of an existing ZIP archive, replacing an
   * entry of the same name.
   *
   * @param file file to add; its file name becomes the entry name
   * @param zip existing ZIP archive
   * @throws IOException if the archive cannot be opened or the copy fails
   */
  public static void addFileToZip(Path file, Path zip) throws IOException {
    try (FileSystem fs = openExistingZip(zip)) {
      Path p = fs.getPath(file.getFileName().toString());
      Files.copy(file, p, StandardCopyOption.REPLACE_EXISTING);
    }
  }

  /**
   * Writes {@code contents} as the entry {@code newFilename} of an existing ZIP
   * archive.
   *
   * @param newFilename name of the entry to write
   * @param contents text to store in the entry
   * @param zip existing ZIP archive
   * @throws IOException if the archive cannot be opened or the write fails
   */
  public static void addNewFileFromStringToZip(String newFilename, String contents, Path zip)
      throws IOException {
    try (FileSystem fs = openExistingZip(zip)) {
      Path p = fs.getPath(newFilename);
      try (Writer writer = Files.newBufferedWriter(p, StandardOpenOption.CREATE)) {
        writer.write(contents);
      }
    }
  }

  /**
   * Returns the uncompressed bytes of the entry {@code filename} of an
   * in-memory ZIP archive.
   *
   * @param content bytes of the ZIP archive
   * @param filename exact name of the entry to extract
   * @return the entry's bytes, or {@code null} if the archive has no such entry
   * @throws IOException if the archive cannot be read
   */
  public static byte[] extractZipEntry(byte[] content, String filename) throws IOException {
    try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(content))) {
      ZipEntry zipEntry;
      while ((zipEntry = zis.getNextEntry()) != null) {
        if (zipEntry.getName().equals(filename)) {
          // ZipEntry.getSize() is -1 for streamed entries (such as those pack()
          // writes), so read until the end of the entry instead of trusting it.
          return readCurrentEntry(zis);
        }
      }
    }
    return null;
  }

  /**
   * Extracts an in-memory ZIP archive into {@code outputPath}, creating the
   * directory and any sub directories as needed and overwriting existing files.
   *
   * @param content bytes of the ZIP archive
   * @param outputPath directory to extract into
   * @throws IOException if an entry would be written outside of
   *     {@code outputPath} or a file cannot be written
   */
  public static void unpack(byte[] content, String outputPath) throws IOException {
    Path targetDir = Paths.get(outputPath).toAbsolutePath().normalize();
    Files.createDirectories(targetDir);

    try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(content))) {
      ZipEntry zipEntry;
      while ((zipEntry = zis.getNextEntry()) != null) {
        String entryName = zipEntry.getName();
        Path target = Paths.get(targetDir.toString(), entryName).normalize();
        // Reject names such as "../x" that would escape the target directory.
        boolean outside = !target.startsWith(targetDir);
        boolean replacesTargetDir = !zipEntry.isDirectory() && target.equals(targetDir);
        if (outside || replacesTargetDir) {
          throw new IOException("zip entry '" + entryName
              + "' is outside of the target directory '" + outputPath + "'");
        }
        if (zipEntry.isDirectory()) {
          Files.createDirectories(target);
        } else {
          Files.createDirectories(target.getParent());
          Files.copy(zis, target, StandardCopyOption.REPLACE_EXISTING);
        }
        zis.closeEntry();
      }
    }
  }

  /** Opens an existing ZIP archive as a file system; never creates the archive. */
  private static FileSystem openExistingZip(Path zip) throws IOException {
    Map<String, String> env = new HashMap<>();
    env.put("create", "false");

    // Path.toUri() yields an absolute, properly escaped file URI, so relative
    // paths and paths containing spaces work too.
    URI uri = URI.create("jar:" + zip.toUri());
    return FileSystems.newFileSystem(uri, env);
  }

  /** Reads the remaining bytes of the entry the stream is positioned at. */
  private static byte[] readCurrentEntry(ZipInputStream zis) throws IOException {
    java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
    byte[] buffer = new byte[BUFFER_SIZE];
    int len;
    while ((len = zis.read(buffer)) != -1) {
      out.write(buffer, 0, len);
    }
    return out.toByteArray();
  }
}

traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ traindbStmts
4141
| bypassDdlStmt
4242
| deleteQueryLogs
4343
| deleteTasks
44+
| exportModel
45+
| importModel
4446
;
4547

4648
createModeltype
@@ -114,6 +116,18 @@ optionValue
114116
| NUMERIC_LITERAL
115117
;
116118

119+
// EXPORT MODEL <modelName>
exportModel
120+
: K_EXPORT K_MODEL modelName
121+
;
122+
123+
// IMPORT MODEL <modelName> FROM <binary string literal>
importModel
124+
: K_IMPORT K_MODEL modelName K_FROM modelBinaryString
125+
;
126+
127+
// The model payload of an import statement, given as a binary string literal.
modelBinaryString
128+
: BINARY_STRING_LITERAL
129+
;
130+
117131
showStmt
118132
: K_SHOW showTargets showWhereClause?
119133
;
@@ -219,9 +233,11 @@ K_DELETE : D E L E T E ;
219233
K_DESC : D E S C ;
220234
K_DESCRIBE : D E S C R I B E ;
221235
K_DROP : D R O P ;
236+
K_EXPORT : E X P O R T ;
222237
K_FOR : F O R ;
223238
K_FROM : F R O M ;
224239
K_HYPERPARAMETERS : H Y P E R P A R A M E T E R S ;
240+
K_IMPORT : I M P O R T ;
225241
K_IN : I N ;
226242
K_INFERENCE : I N F E R E N C E ;
227243
K_LIKE : L I K E ;
@@ -277,6 +293,17 @@ STRING_LITERAL
277293
}
278294
;
279295

296+
// Hexadecimal binary string literal such as x'0A1B' or X'0a1b'.
// The prefix uses the case-insensitive letter fragment X, like the keywords do.
// The action strips the leading x' and the trailing quote, so the token text
// is the bare run of hex digits.
BINARY_STRING_LITERAL
  : X '\'' HEXDIGIT* '\''
    {
      setText(getText().substring(2, getText().length() - 1));
    }
  ;

// Helper for BINARY_STRING_LITERAL only. As a fragment it never becomes a token
// of its own and so cannot compete with other single-character tokens.
fragment HEXDIGIT
  : [a-fA-F0-9]
  ;
306+
280307
WHITESPACES : [ \t\r\n]+ -> channel(HIDDEN) ;
281308

282309
UNEXPECTED_CHAR : . ;

0 commit comments

Comments
 (0)