Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import traindb.catalog.pm.MSynopsis;
import traindb.catalog.pm.MTable;
import traindb.catalog.pm.MTask;
import traindb.catalog.pm.MTrainingStatus;

public interface CatalogContext {

Expand Down Expand Up @@ -62,6 +63,11 @@ Collection<MModel> getInferenceModels(String baseSchema, String baseTable)

MModel getModel(String name);

Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
throws CatalogException;

void updateTrainingStatus(String modelName, String status) throws CatalogException;

/* Synopsis */
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
throws CatalogException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package traindb.catalog;

import com.google.common.collect.ImmutableMap;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import javax.jdo.PersistenceManager;
Expand All @@ -32,6 +35,7 @@
import traindb.catalog.pm.MSynopsis;
import traindb.catalog.pm.MTable;
import traindb.catalog.pm.MTask;
import traindb.catalog.pm.MTrainingStatus;
import traindb.common.TrainDBLogger;

public final class JDOCatalogContext implements CatalogContext {
Expand Down Expand Up @@ -139,10 +143,18 @@ public MModel trainModel(
}
}

MModeltype mModeltype = getModeltype(modeltypeName);
MModel mModel = new MModel(
getModeltype(modeltypeName), modelName, schemaName, tableName, columnNames,
mModeltype, modelName, schemaName, tableName, columnNames,
baseTableRows, trainedRows, options == null ? "" : options);
pm.makePersistent(mModel);

if (mModeltype.getLocation().equals("REMOTE")) {
MTrainingStatus mTrainingStatus = new MTrainingStatus(modelName, "TRAINING",
new Timestamp(System.currentTimeMillis()), mModel);
pm.makePersistent(mTrainingStatus);
}

return mModel;
} catch (RuntimeException e) {
throw new CatalogException("failed to train model '" + modelName + "'", e);
Expand Down Expand Up @@ -178,6 +190,13 @@ public void dropModel(String name) throws CatalogException {
tx.commit();
}

Collection<MTrainingStatus> trainingStatus =
getTrainingStatus(ImmutableMap.of("model_name", name));
if (trainingStatus != null && trainingStatus.size() > 0) {
tx.begin();
pm.deletePersistentAll(trainingStatus);
tx.commit();
}
} catch (RuntimeException e) {
throw new CatalogException("failed to drop model '" + name + "'", e);
} finally {
Expand Down Expand Up @@ -228,6 +247,33 @@ public boolean modelExists(String name) {
return null;
}

@Override
public Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
throws CatalogException {
try {
Query query = pm.newQuery(MTrainingStatus.class);
setFilterPatterns(query, filterPatterns);
return (List<MTrainingStatus>) query.execute();
} catch (RuntimeException e) {
throw new CatalogException("failed to get training status", e);
}
}

@Override
public void updateTrainingStatus(String modelName, String status) throws CatalogException {
try {
Query query = pm.newQuery(MTrainingStatus.class);
setFilterPatterns(query, ImmutableMap.of("model_name", modelName));
List<MTrainingStatus> trainingStatus = (List<MTrainingStatus>) query.execute();
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
MTrainingStatus latestStatus = trainingStatus.stream().max(comparator).get();
latestStatus.setTrainingStatus(status);
pm.makePersistent(latestStatus);
} catch (RuntimeException e) {
throw new CatalogException("failed to get training status", e);
}
}

@Override
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
@Nullable Double ratio) throws CatalogException {
Expand Down
18 changes: 18 additions & 0 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package traindb.catalog.pm;

import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
Expand Down Expand Up @@ -58,6 +60,9 @@ public final class MModel {
@Persistent
private byte[] model_options;

@Persistent(mappedBy = "model", dependentElement = "true")
private Collection<MTrainingStatus> training_status;

public MModel(
MModeltype modeltype, String modelName, String schemaName, String tableName,
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
Expand Down Expand Up @@ -103,4 +108,17 @@ public long getTrainedRows() {
public String getModelOptions() {
return new String(model_options);
}

public Collection<MTrainingStatus> trainingStatus() {
return training_status;
}

public boolean isEnabled() {
if (training_status.isEmpty() || training_status.size() == 0) {
return true;
}
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
MTrainingStatus latestStatus = training_status.stream().max(comparator).get();
return latestStatus.getTrainingStatus().equals("FINISHED");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package traindb.catalog.pm;

import java.sql.Timestamp;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
import javax.jdo.annotations.Index;
import javax.jdo.annotations.PersistenceCapable;
import javax.jdo.annotations.Persistent;
import javax.jdo.annotations.PrimaryKey;
import traindb.catalog.CatalogConstants;

@PersistenceCapable
@Index(name="TRAINING_STATUS_IDX", members={"model_name", "start_time"})
public final class MTrainingStatus {
@PrimaryKey
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
private long id;

@Persistent
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String model_name;

@Persistent
private Timestamp start_time;

@Persistent
@Column(length = 9)
// Status: TRAINING, FINISHED
private String training_status;

@Persistent(dependent = "false")
private MModel model;

public MTrainingStatus(String modelName, String status, Timestamp startTime, MModel model) {
this.model_name = modelName;
this.training_status = status;
this.start_time = startTime;
this.model = model;
}

public String getModelName() {
return model_name;
}

public Timestamp getStartTime() {
return start_time;
}

public String getTrainingStatus() {
return training_status;
}

public MModel getModel() {
return model;
}

public void setTrainingStatus(String status) {
this.training_status = status;
}
}
2 changes: 2 additions & 0 deletions traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ showTargets
| K_SCHEMAS
| K_TABLES
| K_HYPERPARAMETERS
| K_TRAININGS
| K_QUERYLOGS
| K_TASKS
;
Expand Down Expand Up @@ -241,6 +242,7 @@ K_SYNOPSIS : S Y N O P S I S ;
K_TABLES : T A B L E S ;
K_TASKS : T A S K S ;
K_TRAIN : T R A I N ;
K_TRAININGS : T R A I N I N G S;
K_USE : U S E ;
K_WHERE : W H E R E ;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,30 @@ public abstract String infer(String aggregateExpression, String groupByColumn,

public abstract String listHyperparameters(String className, String uri) throws Exception;

public boolean checkAvailable(String modelName) throws Exception {
return true;
}

public Path getModelPath() {
return Paths.get(TrainDBConfiguration.getTrainDBPrefixPath(), "models",
modeltypeName, modelName);
}

public static AbstractTrainDBModelRunner createModelRunner(
TrainDBConnectionImpl conn, CatalogContext catalogContext, TrainDBConfiguration config,
String modeltypeName, String modelName, String location) {
if (location.equals("REMOTE")) {
return new TrainDBFastApiModelRunner(conn, catalogContext, modeltypeName, modelName);
}
// location.equals("LOCAL")
if (config.getModelRunner().equals("py4j")) {
return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName);
}

return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName);
}


protected String buildSelectTrainingDataQuery(String schemaName, String tableName,
List<String> columnNames) {
StringBuilder sb = new StringBuilder();
Expand Down
Loading