Skip to content
4 changes: 4 additions & 0 deletions traindb-catalog/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ limitations under the License.
<groupId>org.xerial</groupId>
<artifactId>sqlite-jdbc</artifactId>
</dependency>
<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;
import org.apache.calcite.rel.type.RelDataType;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.json.simple.JSONObject;
import traindb.catalog.pm.MModel;
import traindb.catalog.pm.MModeltype;
import traindb.catalog.pm.MQueryLog;
Expand Down Expand Up @@ -68,6 +69,9 @@ Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns

void updateTrainingStatus(String modelName, String status) throws CatalogException;

void importModel(String modeltypeName, String modelName, JSONObject exportMetadata)
throws CatalogException;

/* Synopsis */
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
throws CatalogException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import traindb.catalog.pm.MColumn;
import traindb.catalog.pm.MModel;
import traindb.catalog.pm.MModeltype;
Expand Down Expand Up @@ -117,14 +119,15 @@ public MModel trainModel(
String modeltypeName, String modelName, String schemaName, String tableName,
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
MTable mTable;
try {
MSchema mSchema = getSchema(schemaName);
if (mSchema == null) {
mSchema = new MSchema(schemaName);
pm.makePersistent(mSchema);
}

MTable mTable = getTable(schemaName, tableName);
mTable = getTable(schemaName, tableName);
if (mTable == null) {
mTable = new MTable(tableName, "TABLE", mSchema);
pm.makePersistent(mTable);
Expand All @@ -146,7 +149,7 @@ public MModel trainModel(
MModeltype mModeltype = getModeltype(modeltypeName);
MModel mModel = new MModel(
mModeltype, modelName, schemaName, tableName, columnNames,
baseTableRows, trainedRows, options == null ? "" : options);
baseTableRows, trainedRows, options == null ? "" : options, mTable);
pm.makePersistent(mModel);

if (mModeltype.getLocation().equals("REMOTE")) {
Expand Down Expand Up @@ -274,6 +277,54 @@ public void updateTrainingStatus(String modelName, String status) throws Catalog
}
}

@Override
public void importModel(String modeltypeName, String modelName, JSONObject exportMetadata)
throws CatalogException {
try {
String schemaName = (String) exportMetadata.get("schemaName");
MSchema mSchema = getSchema(schemaName);
if (mSchema == null) {
mSchema = pm.makePersistent(new MSchema(schemaName));
}

String tableName = (String) exportMetadata.get("tableName");
MTable mTable = getTable(schemaName, tableName);
if (mTable == null) {
JSONObject jsonTable = (JSONObject) exportMetadata.get("table");
mTable = pm.makePersistent(
new MTable(tableName, (String) jsonTable.get("tableType"), mSchema));

JSONArray jsonColumns = (JSONArray) jsonTable.get("columns");
for (int i = 0; i < jsonColumns.size(); i++) {
JSONObject jsonColumn = (JSONObject) jsonColumns.get(i);
String columnName = (String) jsonColumn.get("columnName");
int columnType = ((Long) jsonColumn.get("columnType")).intValue();
int precision = ((Long) jsonColumn.get("precision")).intValue();
int scale = ((Long) jsonColumn.get("scale")).intValue();
boolean nullable = (boolean) jsonColumn.get("nullable");
MColumn mColumn = new MColumn(columnName, columnType, precision, scale, nullable, mTable);
pm.makePersistent(mColumn);
}
}

MModeltype mModeltype = getModeltype(modeltypeName);
JSONArray jsonColumnNames = (JSONArray) exportMetadata.get("columnNames");
List<String> columnNames = new ArrayList<>();
for (int i = 0; i < jsonColumnNames.size(); i++) {
columnNames.add((String) jsonColumnNames.get(i));
}
String options = (String) exportMetadata.get("modelOptions");
Long baseTableRows = (Long) exportMetadata.get("tableRows");
Long trainedRows = (Long) exportMetadata.get("trainedRows");
MModel mModel = new MModel(
mModeltype, modelName, schemaName, tableName, columnNames,
baseTableRows, trainedRows, options == null ? "" : options, mTable);
pm.makePersistent(mModel);
} catch (RuntimeException e) {
throw new CatalogException("failed to import model '" + modelName + "'", e);
}
}

@Override
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
@Nullable Double ratio) throws CatalogException {
Expand Down
2 changes: 2 additions & 0 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MColumn.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package traindb.catalog.pm;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
import javax.jdo.annotations.PersistenceCapable;
Expand All @@ -22,6 +23,7 @@
import traindb.catalog.CatalogConstants;

@PersistenceCapable
@JsonIgnoreProperties({ "table" })
public final class MColumn {
@PrimaryKey
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
Expand Down
10 changes: 9 additions & 1 deletion traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ public final class MModel {
@Persistent(mappedBy = "model", dependentElement = "true")
private Collection<MTrainingStatus> training_status;

@Persistent(dependent = "false")
private MTable table;

public MModel(
MModeltype modeltype, String modelName, String schemaName, String tableName,
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
String options) {
String options, MTable table) {
this.modeltype = modeltype;
this.model_name = modelName;
this.schema_name = schemaName;
Expand All @@ -75,6 +78,7 @@ public MModel(
this.table_rows = (baseTableRows == null) ? 0 : baseTableRows;
this.trained_rows = (trainedRows == null) ? 0 : trainedRows;
this.model_options = options.getBytes();
this.table = table;
}

public String getModelName() {
Expand Down Expand Up @@ -113,6 +117,10 @@ public Collection<MTrainingStatus> trainingStatus() {
return training_status;
}

public MTable getTable() {
return table;
}

public boolean isEnabled() {
if (training_status.isEmpty() || training_status.size() == 0) {
return true;
Expand Down
2 changes: 2 additions & 0 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package traindb.catalog.pm;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.util.ArrayList;
import java.util.Collection;
import javax.jdo.annotations.Column;
Expand All @@ -24,6 +25,7 @@
import traindb.catalog.CatalogConstants;

@PersistenceCapable
@JsonIgnoreProperties({ "schema" })
public final class MTable {
@PrimaryKey
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
Expand Down
4 changes: 4 additions & 0 deletions traindb-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ limitations under the License.
<groupId>org.apache.calcite</groupId>
<artifactId>calcite-core</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</dependency>
<dependency>
<groupId>sqlline</groupId>
<artifactId>sqlline</artifactId>
Expand Down
135 changes: 135 additions & 0 deletions traindb-common/src/main/java/traindb/util/ZipUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package traindb.util;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.Writer;
import java.net.URI;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.util.HashMap;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import org.apache.commons.io.IOUtils;

public final class ZipUtils {

private ZipUtils() {
}

public static void pack(String sourceDirPath, String zipFilePath) throws IOException {
Path zp = Files.createFile(Paths.get(zipFilePath));
try (ZipOutputStream zs = new ZipOutputStream(Files.newOutputStream(zp))) {
Path sp = Paths.get(sourceDirPath);
Files.walk(sp)
.filter(path -> !Files.isDirectory(path))
.forEach(path -> {
ZipEntry zipEntry = new ZipEntry(sp.relativize(path).toString());
try {
zs.putNextEntry(zipEntry);
Files.copy(path, zs);
zs.closeEntry();
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}
}

public static void addFileToZip(Path file, Path zip) throws IOException {
Map<String, String> env = new HashMap<>();
env.put("create", "false");

URI uri = URI.create("jar:file:" + zip.toString());
try (FileSystem fs = FileSystems.newFileSystem(uri, env)) {
Path p = fs.getPath(file.getFileName().toString());
Files.copy(file, p, StandardCopyOption.REPLACE_EXISTING);
}
}

public static void addNewFileFromStringToZip(String newFilename, String contents, Path zip)
throws IOException {
Map<String, String> env = new HashMap<>();
env.put("create", "false");

URI uri = URI.create("jar:file:" + zip.toString());
try (FileSystem fs = FileSystems.newFileSystem(uri, env)) {
Path p = fs.getPath(newFilename);
try (Writer writer = Files.newBufferedWriter(p, StandardOpenOption.CREATE)) {
writer.write(contents);
}
}
}

public static byte[] extractZipEntry(byte[] content, String filename) throws IOException {
ZipInputStream zis = null;
byte[] bytes = null;
try {
zis = new ZipInputStream(new ByteArrayInputStream(content));
ZipEntry zipEntry;
while ((zipEntry = zis.getNextEntry()) != null) {
if (zipEntry.getName().equals(filename)) {
bytes = IOUtils.readFully(zis, (int) zipEntry.getSize());
break;
}
}
} finally {
if (zis != null) {
zis.close();
}
}
return bytes;
}

public static void unpack(byte[] content, String outputPath) throws IOException {
ZipInputStream zis = null;

try {
File dir = new File(outputPath);
if (!dir.exists()) {
dir.mkdirs();
}
byte[] buffer = new byte[8192];
zis = new ZipInputStream(new ByteArrayInputStream(content));
ZipEntry zipEntry;
while ((zipEntry = zis.getNextEntry()) != null) {
String fileName = zipEntry.getName();
File newFile = new File(outputPath + File.separator + fileName);
new File(newFile.getParent()).mkdirs(); //create directories for sub directories in zip
FileOutputStream fos = new FileOutputStream(newFile);
int len;
while ((len = zis.read(buffer)) > 0) {
fos.write(buffer, 0, len);
}
fos.close();
zis.closeEntry();
}
} finally {
if (zis != null) {
zis.close();
}
}
}
}
27 changes: 27 additions & 0 deletions traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ traindbStmts
| bypassDdlStmt
| deleteQueryLogs
| deleteTasks
| exportModel
| importModel
;

createModeltype
Expand Down Expand Up @@ -114,6 +116,18 @@ optionValue
| NUMERIC_LITERAL
;

exportModel
: K_EXPORT K_MODEL modelName
;

importModel
: K_IMPORT K_MODEL modelName K_FROM modelBinaryString
;

modelBinaryString
: BINARY_STRING_LITERAL
;

showStmt
: K_SHOW showTargets showWhereClause?
;
Expand Down Expand Up @@ -219,9 +233,11 @@ K_DELETE : D E L E T E ;
K_DESC : D E S C ;
K_DESCRIBE : D E S C R I B E ;
K_DROP : D R O P ;
K_EXPORT : E X P O R T ;
K_FOR : F O R ;
K_FROM : F R O M ;
K_HYPERPARAMETERS : H Y P E R P A R A M E T E R S ;
K_IMPORT : I M P O R T ;
K_IN : I N ;
K_INFERENCE : I N F E R E N C E ;
K_LIKE : L I K E ;
Expand Down Expand Up @@ -277,6 +293,17 @@ STRING_LITERAL
}
;

BINARY_STRING_LITERAL
: 'x' '\'' HEXDIGIT* '\''
{
setText(getText().substring(2, getText().length() - 1));
}
;

HEXDIGIT
: [a-fA-F0-9]
;

WHITESPACES : [ \t\r\n]+ -> channel(HIDDEN) ;

UNEXPECTED_CHAR : . ;
Expand Down
Loading