Skip to content

Commit 728f9b1

Browse files
committed
Implement export model and metadata
1 parent debf80d commit 728f9b1

File tree

11 files changed

+312
-8
lines changed

11 files changed

+312
-8
lines changed

traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import com.google.common.collect.ImmutableMap;
1818
import java.sql.Timestamp;
19-
import java.util.ArrayList;
2019
import java.util.Collection;
2120
import java.util.Comparator;
2221
import java.util.List;
@@ -117,14 +116,15 @@ public MModel trainModel(
117116
String modeltypeName, String modelName, String schemaName, String tableName,
118117
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
119118
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
119+
MTable mTable;
120120
try {
121121
MSchema mSchema = getSchema(schemaName);
122122
if (mSchema == null) {
123123
mSchema = new MSchema(schemaName);
124124
pm.makePersistent(mSchema);
125125
}
126126

127-
MTable mTable = getTable(schemaName, tableName);
127+
mTable = getTable(schemaName, tableName);
128128
if (mTable == null) {
129129
mTable = new MTable(tableName, "TABLE", mSchema);
130130
pm.makePersistent(mTable);
@@ -146,7 +146,7 @@ public MModel trainModel(
146146
MModeltype mModeltype = getModeltype(modeltypeName);
147147
MModel mModel = new MModel(
148148
mModeltype, modelName, schemaName, tableName, columnNames,
149-
baseTableRows, trainedRows, options == null ? "" : options);
149+
baseTableRows, trainedRows, options == null ? "" : options, mTable);
150150
pm.makePersistent(mModel);
151151

152152
if (mModeltype.getLocation().equals("REMOTE")) {

traindb-catalog/src/main/java/traindb/catalog/pm/MColumn.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package traindb.catalog.pm;
1616

17+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1718
import javax.jdo.annotations.Column;
1819
import javax.jdo.annotations.IdGeneratorStrategy;
1920
import javax.jdo.annotations.PersistenceCapable;
@@ -22,6 +23,7 @@
2223
import traindb.catalog.CatalogConstants;
2324

2425
@PersistenceCapable
26+
@JsonIgnoreProperties({ "table" })
2527
public final class MColumn {
2628
@PrimaryKey
2729
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)

traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,13 @@ public final class MModel {
6363
@Persistent(mappedBy = "model", dependentElement = "true")
6464
private Collection<MTrainingStatus> training_status;
6565

66+
@Persistent(dependent = "false")
67+
private MTable table;
68+
6669
public MModel(
6770
MModeltype modeltype, String modelName, String schemaName, String tableName,
6871
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
69-
String options) {
72+
String options, MTable table) {
7073
this.modeltype = modeltype;
7174
this.model_name = modelName;
7275
this.schema_name = schemaName;
@@ -75,6 +78,7 @@ public MModel(
7578
this.table_rows = (baseTableRows == null) ? 0 : baseTableRows;
7679
this.trained_rows = (trainedRows == null) ? 0 : trainedRows;
7780
this.model_options = options.getBytes();
81+
this.table = table;
7882
}
7983

8084
public String getModelName() {
@@ -113,6 +117,10 @@ public Collection<MTrainingStatus> trainingStatus() {
113117
return training_status;
114118
}
115119

120+
public MTable getTable() {
121+
return table;
122+
}
123+
116124
public boolean isEnabled() {
117125
if (training_status.isEmpty() || training_status.size() == 0) {
118126
return true;

traindb-catalog/src/main/java/traindb/catalog/pm/MTable.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package traindb.catalog.pm;
1616

17+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1718
import java.util.ArrayList;
1819
import java.util.Collection;
1920
import javax.jdo.annotations.Column;
@@ -24,6 +25,7 @@
2425
import traindb.catalog.CatalogConstants;
2526

2627
@PersistenceCapable
28+
@JsonIgnoreProperties({ "schema" })
2729
public final class MTable {
2830
@PrimaryKey
2931
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.util;
16+
17+
import java.io.IOException;
18+
import java.io.Writer;
19+
import java.net.URI;
20+
import java.nio.file.FileSystem;
21+
import java.nio.file.FileSystems;
22+
import java.nio.file.Files;
23+
import java.nio.file.Path;
24+
import java.nio.file.Paths;
25+
import java.nio.file.StandardCopyOption;
26+
import java.nio.file.StandardOpenOption;
27+
import java.util.HashMap;
28+
import java.util.Map;
29+
import java.util.zip.ZipEntry;
30+
import java.util.zip.ZipOutputStream;
31+
32+
public final class ZipUtils {
33+
34+
private ZipUtils() {
35+
}
36+
37+
public static void pack(String sourceDirPath, String zipFilePath) throws IOException {
38+
Path zp = Files.createFile(Paths.get(zipFilePath));
39+
try (ZipOutputStream zs = new ZipOutputStream(Files.newOutputStream(zp))) {
40+
Path sp = Paths.get(sourceDirPath);
41+
Files.walk(sp)
42+
.filter(path -> !Files.isDirectory(path))
43+
.forEach(path -> {
44+
ZipEntry zipEntry = new ZipEntry(sp.relativize(path).toString());
45+
try {
46+
zs.putNextEntry(zipEntry);
47+
Files.copy(path, zs);
48+
zs.closeEntry();
49+
} catch (IOException e) {
50+
throw new RuntimeException(e);
51+
}
52+
});
53+
}
54+
}
55+
56+
public static void addFileToZip(Path file, Path zip) throws IOException {
57+
Map<String, String> env = new HashMap<>();
58+
env.put("create", "false");
59+
60+
URI uri = URI.create("jar:file:" + zip.toString());
61+
try (FileSystem fs = FileSystems.newFileSystem(uri, env)) {
62+
Path p = fs.getPath(file.getFileName().toString());
63+
Files.copy(file, p, StandardCopyOption.REPLACE_EXISTING);
64+
}
65+
}
66+
67+
public static void addNewFileFromStringToZip(String newFilename, String contents, Path zip)
68+
throws IOException {
69+
Map<String, String> env = new HashMap<>();
70+
env.put("create", "false");
71+
72+
URI uri = URI.create("jar:file:" + zip.toString());
73+
try (FileSystem fs = FileSystems.newFileSystem(uri, env)) {
74+
Path p = fs.getPath(newFilename);
75+
try (Writer writer = Files.newBufferedWriter(p, StandardOpenOption.CREATE)) {
76+
writer.write(contents);
77+
}
78+
}
79+
}
80+
}

traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.json.simple.JSONObject;
2929
import traindb.catalog.CatalogContext;
3030
import traindb.common.TrainDBConfiguration;
31+
import traindb.common.TrainDBException;
3132
import traindb.jdbc.TrainDBConnectionImpl;
3233
import traindb.schema.TrainDBTable;
3334

@@ -57,6 +58,8 @@ public abstract String infer(String aggregateExpression, String groupByColumn,
5758

5859
public abstract String listHyperparameters(String className, String uri) throws Exception;
5960

61+
public abstract void exportModel(String outputPath) throws Exception;
62+
6063
public boolean checkAvailable(String modelName) throws Exception {
6164
return true;
6265
}

traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ public String listHyperparameters(String className, String uri) throws Exception
231231
return unescapeString(response.toString());
232232
}
233233

234+
@Override
235+
public void exportModel(String outputPath) throws Exception {
236+
throw new TrainDBException("Not supported yet");
237+
}
238+
234239
@Override
235240
public boolean checkAvailable(String modelName) throws Exception {
236241
MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype();

traindb-core/src/main/java/traindb/engine/TrainDBFileModelRunner.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import traindb.common.TrainDBException;
3333
import traindb.jdbc.TrainDBConnectionImpl;
3434
import traindb.schema.TrainDBTable;
35+
import traindb.util.ZipUtils;
3536

3637
public class TrainDBFileModelRunner extends AbstractTrainDBModelRunner {
3738

@@ -151,4 +152,9 @@ public String listHyperparameters(String className, String uri) throws Exception
151152
return hyperparamsInfo;
152153
}
153154

155+
@Override
156+
public void exportModel(String outputPath) throws Exception {
157+
String modelPath = getModelPath().toString();
158+
ZipUtils.pack(modelPath, outputPath);
159+
}
154160
}

traindb-core/src/main/java/traindb/engine/TrainDBPy4JModelRunner.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import traindb.common.TrainDBException;
3838
import traindb.jdbc.TrainDBConnectionImpl;
3939
import traindb.schema.TrainDBTable;
40+
import traindb.util.ZipUtils;
4041

4142
public class TrainDBPy4JModelRunner extends AbstractTrainDBModelRunner {
4243

@@ -109,6 +110,7 @@ public void trainModel(
109110
server.shutdown();
110111
}
111112

113+
@Override
112114
public void generateSynopsis(String outputPath, int rows) throws Exception {
113115
String modelPath = getModelPath().toString();
114116
MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype();
@@ -153,6 +155,12 @@ public String listHyperparameters(String className, String uri) throws Exception
153155
return hyperparamsInfo;
154156
}
155157

158+
@Override
159+
public void exportModel(String outputPath) throws Exception {
160+
String modelPath = getModelPath().toString();
161+
ZipUtils.pack(modelPath, outputPath);
162+
}
163+
156164
private int getAvailablePort() throws Exception {
157165
ServerSocket s;
158166
try {

traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@
1414

1515
package traindb.engine;
1616

17+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
18+
import com.fasterxml.jackson.databind.ObjectMapper;
1719
import com.google.common.collect.ImmutableList;
1820
import com.opencsv.CSVReader;
1921
import com.opencsv.CSVReaderBuilder;
22+
import java.io.File;
23+
import java.io.FileInputStream;
2024
import java.io.FileReader;
25+
import java.nio.file.Path;
26+
import java.nio.file.Paths;
2127
import java.sql.PreparedStatement;
2228
import java.sql.ResultSet;
2329
import java.sql.SQLException;
@@ -27,8 +33,6 @@
2733
import java.util.List;
2834
import java.util.Map;
2935
import org.apache.calcite.sql.SqlDialect;
30-
import org.codehaus.jackson.annotate.JsonIgnoreProperties;
31-
import org.codehaus.jackson.map.ObjectMapper;
3236
import org.json.simple.JSONObject;
3337
import traindb.adapter.TrainDBSqlDialect;
3438
import traindb.catalog.CatalogContext;
@@ -39,14 +43,15 @@
3943
import traindb.catalog.pm.MSynopsis;
4044
import traindb.catalog.pm.MTask;
4145
import traindb.catalog.pm.MTrainingStatus;
42-
import traindb.common.TrainDBConfiguration;
4346
import traindb.common.TrainDBException;
4447
import traindb.common.TrainDBLogger;
48+
import traindb.engine.nio.ByteArray;
4549
import traindb.jdbc.TrainDBConnectionImpl;
4650
import traindb.schema.SchemaManager;
4751
import traindb.schema.TrainDBTable;
4852
import traindb.sql.TrainDBSqlRunner;
4953
import traindb.task.TaskTracer;
54+
import traindb.util.ZipUtils;
5055

5156

5257
public class TrainDBQueryEngine implements TrainDBSqlRunner {
@@ -683,7 +688,68 @@ public void deleteTasks(Integer cnt) throws Exception {
683688

684689
@Override
685690
public TrainDBListResultSet exportModel(String modelName) throws Exception {
686-
throw new TrainDBException("Not supported yet");
691+
T_tracer.startTaskTracer("export model " + modelName);
692+
693+
T_tracer.openTaskTime("find : model");
694+
if (!catalogContext.modelExists(modelName)) {
695+
String msg = "model '" + modelName + "' does not exist";
696+
697+
T_tracer.closeTaskTime(msg);
698+
T_tracer.endTaskTracer();
699+
700+
throw new CatalogException(msg);
701+
}
702+
T_tracer.closeTaskTime("SUCCESS");
703+
704+
T_tracer.openTaskTime("export model");
705+
MModel mModel = catalogContext.getModel(modelName);
706+
MModeltype mModeltype = mModel.getModeltype();
707+
708+
AbstractTrainDBModelRunner runner = createModelRunner(
709+
mModeltype.getModeltypeName(), modelName, mModeltype.getLocation());
710+
711+
if (!mModel.isEnabled()) { // remote model
712+
if (!runner.checkAvailable(modelName)) {
713+
throw new TrainDBException(
714+
"model '" + modelName + "' is not available (training is not finished)");
715+
}
716+
catalogContext.updateTrainingStatus(modelName, "FINISHED");
717+
}
718+
719+
Path outputPath = Paths.get(runner.getModelPath().getParent().toString(), modelName + ".zip");
720+
runner.exportModel(outputPath.toString());
721+
722+
ObjectMapper mapper = new ObjectMapper();
723+
ZipUtils.addNewFileFromStringToZip("export_metadata.json",
724+
mapper.writeValueAsString(mModel), outputPath);
725+
726+
List<String> header = Arrays.asList("export_model");
727+
List<List<Object>> exportModelInfo = new ArrayList<>();
728+
ByteArray byteArray = convertFileToByteArray(new File(outputPath.toString()));
729+
exportModelInfo.add(Arrays.asList(byteArray));
730+
731+
T_tracer.closeTaskTime("SUCCESS");
732+
T_tracer.endTaskTracer();
733+
734+
return new TrainDBListResultSet(header, exportModelInfo);
735+
}
736+
737+
private ByteArray convertFileToByteArray(File file) throws Exception {
738+
byte[] bytes;
739+
FileInputStream inputStream = null;
740+
741+
try {
742+
bytes = new byte[(int) file.length()];
743+
inputStream = new FileInputStream(file);
744+
inputStream.read(bytes);
745+
} catch (Exception e) {
746+
throw e;
747+
} finally {
748+
if (inputStream != null) {
749+
inputStream.close();
750+
}
751+
}
752+
return new ByteArray(bytes);
687753
}
688754

689755
private void checkShowWhereColumns(Map<String, Object> patterns, List<String> columns)

0 commit comments

Comments
 (0)