Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Do

void dropSynopsis(String name) throws CatalogException;

void importSynopsis(String synopsisName, JSONObject exportMetadata) throws CatalogException;

void renameSynopsis(String synopsisName, String newSynopsisName) throws CatalogException;

void enableSynopsis(String synopsisName) throws CatalogException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,36 +277,43 @@ public void updateTrainingStatus(String modelName, String status) throws Catalog
}
}

private void importTable(String schemaName, JSONObject jsonTableMetadata) {
MSchema mSchema = getSchema(schemaName);
if (mSchema == null) {
mSchema = pm.makePersistent(new MSchema(schemaName));
}

String tableName = (String) jsonTableMetadata.get("tableName");
MTable mTable = getTable(schemaName, tableName);
if (mTable == null) {
JSONObject jsonTable = (JSONObject) jsonTableMetadata.get("table");
mTable = pm.makePersistent(
new MTable(tableName, (String) jsonTable.get("tableType"), mSchema));

JSONArray jsonColumns = (JSONArray) jsonTable.get("columns");
for (int i = 0; i < jsonColumns.size(); i++) {
JSONObject jsonColumn = (JSONObject) jsonColumns.get(i);
String columnName = (String) jsonColumn.get("columnName");
int columnType = ((Long) jsonColumn.get("columnType")).intValue();
int precision = ((Long) jsonColumn.get("precision")).intValue();
int scale = ((Long) jsonColumn.get("scale")).intValue();
boolean nullable = (boolean) jsonColumn.get("nullable");
MColumn mColumn = new MColumn(columnName, columnType, precision, scale, nullable, mTable);
pm.makePersistent(mColumn);
}
}
}

@Override
public void importModel(String modeltypeName, String modelName, JSONObject exportMetadata)
throws CatalogException {
try {
String schemaName = (String) exportMetadata.get("schemaName");
MSchema mSchema = getSchema(schemaName);
if (mSchema == null) {
mSchema = pm.makePersistent(new MSchema(schemaName));
}
JSONObject jsonTable = (JSONObject) exportMetadata.get("table");
importTable(schemaName, jsonTable);

String tableName = (String) exportMetadata.get("tableName");
MTable mTable = getTable(schemaName, tableName);
if (mTable == null) {
JSONObject jsonTable = (JSONObject) exportMetadata.get("table");
mTable = pm.makePersistent(
new MTable(tableName, (String) jsonTable.get("tableType"), mSchema));

JSONArray jsonColumns = (JSONArray) jsonTable.get("columns");
for (int i = 0; i < jsonColumns.size(); i++) {
JSONObject jsonColumn = (JSONObject) jsonColumns.get(i);
String columnName = (String) jsonColumn.get("columnName");
int columnType = ((Long) jsonColumn.get("columnType")).intValue();
int precision = ((Long) jsonColumn.get("precision")).intValue();
int scale = ((Long) jsonColumn.get("scale")).intValue();
boolean nullable = (boolean) jsonColumn.get("nullable");
MColumn mColumn = new MColumn(columnName, columnType, precision, scale, nullable, mTable);
pm.makePersistent(mColumn);
}
}

MModeltype mModeltype = getModeltype(modeltypeName);
JSONArray jsonColumnNames = (JSONArray) exportMetadata.get("columnNames");
List<String> columnNames = new ArrayList<>();
Expand Down Expand Up @@ -362,7 +369,8 @@ public void disableModel(String modelName) throws CatalogException {
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
@Nullable Double ratio) throws CatalogException {
try {
MSynopsis mSynopsis = new MSynopsis(synopsisName, rows, ratio, getModel(modelName));
MModel mModel = getModel(modelName);
MSynopsis mSynopsis = new MSynopsis(synopsisName, rows, ratio, mModel, mModel.getTable());
pm.makePersistent(mSynopsis);
return mSynopsis;
} catch (RuntimeException e) {
Expand All @@ -378,8 +386,7 @@ public Collection<MSynopsis> getAllSynopses() throws CatalogException {
@Override
public Collection<MSynopsis> getAllSynopses(String baseSchema, String baseTable)
throws CatalogException {
return getAllSynopses(ImmutableMap.of(
"model.schema_name", baseSchema, "model.table_name", baseTable));
return getAllSynopses(ImmutableMap.of("schema_name", baseSchema, "table_name", baseTable));
}

@Override
Expand Down Expand Up @@ -430,6 +437,34 @@ public void dropSynopsis(String name) throws CatalogException {
}
}

@Override
public void importSynopsis(String synopsisName, JSONObject exportMetadata)
throws CatalogException {
try {
String schemaName = (String) exportMetadata.get("schemaName");
JSONObject jsonTable = (JSONObject) exportMetadata.get("table");
importTable(schemaName, jsonTable);

String tableName = (String) exportMetadata.get("tableName");
MTable mTable = getTable(schemaName, tableName);
JSONArray jsonColumnNames = (JSONArray) exportMetadata.get("columnNames");
List<String> columnNames = new ArrayList<>();
for (int i = 0; i < jsonColumnNames.size(); i++) {
columnNames.add((String) jsonColumnNames.get(i));
}

Integer rows = ((Long) exportMetadata.get("rows")).intValue();
Double ratio = (Double) exportMetadata.get("ratio");

MSynopsis mSynopsis = new MSynopsis(synopsisName, rows, ratio, "-", schemaName, tableName,
columnNames,mTable);
pm.makePersistent(mSynopsis);
} catch (RuntimeException e) {
throw new CatalogException("failed to import synopsis '" + synopsisName + "'", e);
}

}

@Override
public void renameSynopsis(String synopsisName, String newSynopsisName) throws CatalogException {
try {
Expand Down
74 changes: 72 additions & 2 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MSynopsis.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package traindb.catalog.pm;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.util.List;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
import javax.jdo.annotations.PersistenceCapable;
Expand All @@ -23,6 +25,7 @@
import traindb.catalog.CatalogConstants;

@PersistenceCapable
@JsonIgnoreProperties({ "model" })
public final class MSynopsis {
@PrimaryKey
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
Expand All @@ -33,6 +36,21 @@ public final class MSynopsis {
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String synopsis_name;

@Persistent
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String model_name;

@Persistent
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String schema_name;

@Persistent
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String table_name;

@Persistent
private List<String> columns;

@Persistent
private int rows;

Expand All @@ -46,12 +64,28 @@ public final class MSynopsis {
@Persistent(dependent = "false")
private MModel model;

public MSynopsis(String name, Integer rows, Double ratio, MModel model) {
@Persistent(dependent = "false")
private MTable table;

public MSynopsis(String name, Integer rows, Double ratio, MModel model, MTable table) {
this(name, rows, ratio, model.getModelName(), model.getSchemaName(), model.getTableName(),
model.getColumnNames(), table);
this.model = model;
}

public MSynopsis(String name, Integer rows, Double ratio, String modelName, String schemaName,
String tableName, List<String> columns, MTable table) {
this.synopsis_name = name;
this.rows = rows;
this.ratio = (ratio == null) ? 0 : ratio;
this.synopsis_status = "ENABLED"; // initial status
this.model = model;
this.model_name = modelName;
this.schema_name = schemaName;
this.table_name = tableName;
this.columns = columns;
this.table = table;

this.model = null;
}

public String getSynopsisName() {
Expand All @@ -70,6 +104,26 @@ public MModel getModel() {
return model;
}

public MTable getTable() {
return table;
}

public String getModelName() {
return model_name;
}

public String getSchemaName() {
return schema_name;
}

public String getTableName() {
return table_name;
}

public List<String> getColumnNames() {
return columns;
}

public String getSynopsisStatus() {
return synopsis_status;
}
Expand All @@ -89,4 +143,20 @@ public void disableSynopsis() {
public boolean isEnabled() {
return synopsis_status.equals("ENABLED");
}

public void setModelName(String modelName) {
this.model_name = modelName;
}

public void setSchemaName(String schemaName) {
this.schema_name = schemaName;
}

public void setTableName(String tableName) {
this.table_name = tableName;
}

public void setColumnNames(List<String> columns) {
this.columns = columns;
}
}
14 changes: 14 additions & 0 deletions traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ traindbStmts
| deleteTasks
| exportModel
| importModel
| exportSynopsis
| importSynopsis
| incrementalQuery
;

Expand Down Expand Up @@ -154,6 +156,18 @@ modelBinaryString
: BINARY_STRING_LITERAL
;

exportSynopsis
: K_EXPORT K_SYNOPSIS synopsisName
;

importSynopsis
: K_IMPORT K_SYNOPSIS synopsisName K_FROM synopsisBinaryString
;

synopsisBinaryString
: BINARY_STRING_LITERAL
;

showStmt
: K_SHOW showTargets showWhereClause?
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

package traindb.engine;

import com.opencsv.CSVWriter;
import com.opencsv.ResultSetHelperService;
import java.io.FileWriter;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.sql.DatabaseMetaData;
Expand Down Expand Up @@ -91,9 +94,9 @@ private String getTableSampleClause(float samplePercent) throws TrainDBException
try {
String connDbms = conn.getMetaData().getURL().split(":")[1];
if (connDbms.equals("mysql")) {
return "WHERE rand() < " + (samplePercent / 100.0);
return " WHERE rand() < " + (samplePercent / 100.0);
} else if (connDbms.equals("postgresql")) {
return "TABLESAMPLE BERNOULLI(" + samplePercent + ")";
return " TABLESAMPLE BERNOULLI(" + samplePercent + ")";
}
} catch (SQLException e) {
// ignore
Expand All @@ -104,6 +107,15 @@ private String getTableSampleClause(float samplePercent) throws TrainDBException
protected String buildSelectTrainingDataQuery(
String schemaName, String tableName, List<String> columnNames, float samplePercent,
RelDataType relDataType) throws TrainDBException {
String query = buildExportTableQuery(schemaName, tableName, columnNames, relDataType);
if (samplePercent > 0 && samplePercent < 100) {
query = query + getTableSampleClause(samplePercent);
}
return query;
}

public static String buildExportTableQuery(String schemaName, String tableName,
List<String> columnNames, RelDataType relDataType) {
StringBuilder sb = new StringBuilder();
sb.append("SELECT ");
for (int i = 0; i < columnNames.size(); i++) {
Expand All @@ -121,13 +133,21 @@ protected String buildSelectTrainingDataQuery(
sb.append(schemaName);
sb.append(".");
sb.append(tableName);
if (samplePercent > 0 && samplePercent < 100) {
sb.append(" ").append(getTableSampleClause(samplePercent));
}

return sb.toString();
}

public static void writeResultSetToCsv(ResultSet rs, String filePath) throws Exception {
FileWriter datafileWriter = new FileWriter(filePath);
CSVWriter csvWriter = new CSVWriter(datafileWriter, ',');
ResultSetHelperService resultSetHelperService = new ResultSetHelperService();
resultSetHelperService.setDateFormat("yyyy-MM-dd");
resultSetHelperService.setDateTimeFormat("yyyy-MM-dd HH:MI:SS");
csvWriter.setResultService(resultSetHelperService);
csvWriter.writeAll(rs, true);
csvWriter.close();
}

protected JSONObject buildTableMetadata(
String schemaName, String tableName, List<String> columnNames,
Map<String, Object> trainOptions, RelDataType relDataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package traindb.engine;

import com.opencsv.CSVWriter;
import com.opencsv.ResultSetHelperService;
import java.io.File;
import java.io.FileWriter;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -60,7 +58,7 @@ public void trainModel(TrainDBTable table, List<String> columnNames, float sampl
Path modelPath = getModelPath();
Files.createDirectories(modelPath);
String outputPath = modelPath.toString();
String metadataFilename = outputPath + "/metadata.json";
String metadataFilename = Paths.get(outputPath, "metadata.json").toString();
FileWriter fileWriter = new FileWriter(metadataFilename);
fileWriter.write(tableMetadata.toJSONString());
fileWriter.flush();
Expand All @@ -69,15 +67,8 @@ public void trainModel(TrainDBTable table, List<String> columnNames, float sampl
String sql = buildSelectTrainingDataQuery(schemaName, tableName, columnNames, samplePercent,
table.getRowType(typeFactory));
ResultSet trainingData = conn.executeQueryInternal(sql);
String dataFilename = outputPath + "/data.csv";
FileWriter datafileWriter = new FileWriter(dataFilename);
CSVWriter csvWriter = new CSVWriter(datafileWriter, ',');
ResultSetHelperService resultSetHelperService = new ResultSetHelperService();
resultSetHelperService.setDateFormat("yyyy-MM-dd");
resultSetHelperService.setDateTimeFormat("yyyy-MM-dd HH:MI:SS");
csvWriter.setResultService(resultSetHelperService);
csvWriter.writeAll(trainingData, true);
csvWriter.close();
String dataFilename = Paths.get(outputPath, "data.csv").toString();
writeResultSetToCsv(trainingData, dataFilename);
trainingData.close();

MModeltype mModeltype = catalogContext.getModeltype(modeltypeName);
Expand Down
Loading