Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

package traindb.catalog;

import java.nio.file.Path;
import java.util.Collection;
import java.util.List;
import traindb.catalog.pm.MModel;
Expand Down Expand Up @@ -44,12 +43,13 @@ MModel trainModel(

/**
 * Returns the models held by this catalog context.
 *
 * @return a collection of {@link MModel} entries
 * @throws CatalogException if the lookup fails; the exact failure conditions are
 *     implementation-specific
 */
Collection<MModel> getModels() throws CatalogException;

/**
 * Returns the inference models associated with the given base table.
 *
 * <p>The JDO-backed implementation selects models whose {@code schemaName} and
 * {@code tableName} equal the arguments and whose model type is {@code "INFERENCE"}.
 *
 * @param baseSchema schema name the models must match
 * @param baseTable table name the models must match
 * @return the matching models
 * @throws CatalogException if the lookup fails
 */
Collection<MModel> getInferenceModels(String baseSchema, String baseTable)
throws CatalogException;

/**
 * Reports whether a model with the given name exists.
 *
 * <p>The JDO-backed implementation answers this with {@code getModel(name) != null}.
 *
 * @param name model name to look up
 * @return {@code true} if such a model exists
 */
boolean modelExists(String name);

/**
 * Looks up a model by name.
 *
 * <p>NOTE(review): callers such as {@code modelExists} treat a {@code null} result as
 * "not found" -- confirm that every implementation returns {@code null} rather than
 * throwing in that case.
 *
 * @param name model name to look up
 * @return the model with that name
 */
MModel getModel(String name);

/**
 * Returns the filesystem path for the given model.
 *
 * <p>The JDO-backed implementation builds
 * {@code <TrainDB prefix path>/models/<modeltypeName>/<modelName>}.
 *
 * @param modeltypeName name of the model type; used as a path segment
 * @param modelName name of the model; used as a path segment
 * @return the path for the model
 */
Path getModelPath(String modeltypeName, String modelName);

/**
 * Creates a synopsis entry in the catalog.
 *
 * @param synopsisName name of the synopsis to create
 * @param modelName name of the model the synopsis refers to
 * @param rows row count recorded for the synopsis -- presumably the number of
 *     synthesized rows; TODO confirm against the implementation
 * @param ratio ratio recorded for the synopsis; the JDO-backed implementation declares
 *     this parameter {@code @Nullable}
 * @return the created {@link MSynopsis}
 * @throws CatalogException if the synopsis cannot be created
 */
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
throws CatalogException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package traindb.catalog;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.List;
import javax.jdo.PersistenceManager;
Expand All @@ -25,7 +23,6 @@
import traindb.catalog.pm.MModel;
import traindb.catalog.pm.MModeltype;
import traindb.catalog.pm.MSynopsis;
import traindb.common.TrainDBConfiguration;
import traindb.common.TrainDBLogger;

public final class JDOCatalogContext implements CatalogContext {
Expand Down Expand Up @@ -139,6 +136,20 @@ public Collection<MModel> getModels() throws CatalogException {
}
}

/**
 * Fetches every {@link MModel} whose schema and table names equal the arguments and
 * whose model type is {@code "INFERENCE"}.
 *
 * @param baseSchema value compared against the model's {@code schemaName}
 * @param baseTable value compared against the model's {@code tableName}
 * @return the query result, viewed as a list of models
 * @throws CatalogException wrapping any {@link RuntimeException} raised while building
 *     or executing the query
 */
@Override
public Collection<MModel> getInferenceModels(String baseSchema, String baseTable)
    throws CatalogException {
  final String filter =
      "schemaName == baseSchema && tableName == baseTable && modeltype.type == \"INFERENCE\"";
  final String parameters = "String baseSchema, String baseTable";
  try {
    Query inferenceQuery = pm.newQuery(MModel.class);
    inferenceQuery.setFilter(filter);
    inferenceQuery.declareParameters(parameters);
    Object matches = inferenceQuery.execute(baseSchema, baseTable);
    return (List<MModel>) matches;
  } catch (RuntimeException e) {
    throw new CatalogException("failed to get models", e);
  }
}

@Override
public boolean modelExists(String name) {
return getModel(name) != null;
Expand All @@ -157,12 +168,6 @@ public boolean modelExists(String name) {
return null;
}

/**
 * Builds the path {@code <prefix>/models/<modeltypeName>/<modelName>}, where the prefix
 * comes from {@link TrainDBConfiguration#getTrainDBPrefixPath()}.
 *
 * @param modeltypeName model type name, used as a path segment
 * @param modelName model name, used as the final path segment
 * @return the composed path
 */
@Override
public Path getModelPath(String modeltypeName, String modelName) {
  final String prefix = TrainDBConfiguration.getTrainDBPrefixPath();
  return Paths.get(prefix, "models", modeltypeName, modelName);
}

@Override
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
@Nullable Double ratio) throws CatalogException {
Expand Down
4 changes: 4 additions & 0 deletions traindb-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ limitations under the License.
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
</dependency>
<dependency>
<groupId>net.sf.opencsv</groupId>
<artifactId>opencsv</artifactId>
</dependency>

<!-- Test Support -->
<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package traindb.adapter.python;

import com.google.common.collect.ImmutableList;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import traindb.catalog.pm.MModel;

/**
* Implementation of {@link Aggregate} relational expression for Python ML.
*/
/**
 * Implementation of {@link Aggregate} relational expression for Python ML.
 *
 * <p>Each instance carries the {@link MModel} it was created with and hands that same
 * model to every copy. Only simple grouping and the aggregate functions listed in
 * {@link #SUPPORTED_AGGREGATIONS} are accepted; anything else is rejected at construction
 * time with an {@link InvalidRelException}.
 */
public class PythonMLAggregate extends Aggregate implements PythonRel {

  /** Aggregate function kinds this node accepts. */
  private static final Set<SqlKind> SUPPORTED_AGGREGATIONS =
      EnumSet.of(SqlKind.COUNT, SqlKind.AVG, SqlKind.SUM, SqlKind.STDDEV_POP, SqlKind.VAR_POP,
          SqlKind.ANY_VALUE);

  /** Factor applied to the cost computed by {@link Aggregate} for this node. */
  private static final double COST_MULTIPLIER = 0.1;

  final MModel model;

  /**
   * Creates a PythonMLAggregate.
   *
   * @param cluster cluster this expression belongs to
   * @param traitSet traits of this expression
   * @param input input relational expression
   * @param groupSet bit set of grouping columns
   * @param groupSets grouping sets; exactly one is expected
   * @param aggCalls aggregate calls; each must be of a supported kind
   * @param model model associated with this aggregate
   * @throws InvalidRelException if an aggregate call is of an unsupported kind or the
   *     grouping is not {@link Group#SIMPLE}
   */
  public PythonMLAggregate(RelOptCluster cluster,
                           RelTraitSet traitSet,
                           RelNode input,
                           ImmutableBitSet groupSet,
                           List<ImmutableBitSet> groupSets,
                           List<AggregateCall> aggCalls,
                           MModel model) throws InvalidRelException {
    super(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets, aggCalls);
    this.model = model;

    assert this.groupSets.size() == 1 : "Grouping sets not supported";

    for (AggregateCall aggCall : aggCalls) {
      final SqlKind kind = aggCall.getAggregation().getKind();
      if (!SUPPORTED_AGGREGATIONS.contains(kind)) {
        final String message = String.format(Locale.ROOT,
            "Aggregation %s not supported (use one of %s)", kind, SUPPORTED_AGGREGATIONS);
        throw new InvalidRelException(message);
      }
    }

    if (getGroupType() != Group.SIMPLE) {
      final String message = String.format(Locale.ROOT, "Only %s grouping is supported. "
          + "Yours is %s", Group.SIMPLE, getGroupType());
      throw new InvalidRelException(message);
    }
  }

  /**
   * Creates a copy with the given traits, input, grouping and aggregate calls, keeping
   * this node's cluster and model.
   *
   * @throws AssertionError if the arguments fail the constructor's validation
   */
  @Override
  public Aggregate copy(RelTraitSet traitSet, RelNode input,
                        ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
                        List<AggregateCall> aggCalls) {
    try {
      return new PythonMLAggregate(getCluster(), traitSet, input,
          groupSet, groupSets, aggCalls, model);
    } catch (InvalidRelException e) {
      throw new AssertionError(e);
    }
  }

  /**
   * Returns the cost computed by {@link Aggregate} scaled by {@link #COST_MULTIPLIER},
   * or {@code null} when the superclass reports no cost.
   */
  @Override
  public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
                                              RelMetadataQuery mq) {
    final RelOptCost baseCost = super.computeSelfCost(planner, mq);
    if (baseCost == null) {
      // The superclass result is nullable; propagate "no cost" instead of failing
      // with a NullPointerException.
      return null;
    }
    return baseCost.multiplyBy(COST_MULTIPLIER);
  }

  /** Registers {@link PythonToEnumerableConverterRule#INSTANCE} with the planner. */
  @Override
  public void register(RelOptPlanner planner) {
    planner.addRule(PythonToEnumerableConverterRule.INSTANCE);
  }

  /**
   * Returns the input's implementation result unchanged; this method adds nothing of its
   * own to it. The input is cast to {@link PythonRel}, so a different kind of input fails
   * with a {@link ClassCastException}.
   */
  @Override
  public Result implement() {
    return ((PythonRel) getInput()).implement();
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package traindb.adapter.python;

import au.com.bytecode.opencsv.CSVReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.calcite.linq4j.AbstractEnumerable;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.linq4j.Linq4j;
import org.apache.calcite.sql.type.SqlTypeName;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Enumerable over the rows of a CSV file, with each cell converted to a Java value
 * according to the {@link SqlTypeName} declared for its column.
 *
 * <p>Cells are matched to {@link #fields} by position, following the iteration order of
 * that map, so the map must iterate in CSV column order (presumably callers pass an
 * insertion-ordered map -- TODO confirm). The whole file is read each time
 * {@link #enumerator()} is called.
 */
public class PythonMLAggregateEnumerable extends AbstractEnumerable<Object[]> {

  public final String csvResultPath;
  public final Map<String, SqlTypeName> fields;

  /**
   * Creates an enumerable over a CSV result file.
   *
   * @param csvResultPath path of the CSV file to read
   * @param fields column name to SQL type, iterated in CSV column order
   */
  public PythonMLAggregateEnumerable(String csvResultPath, Map<String, SqlTypeName> fields) {
    this.csvResultPath = csvResultPath;
    this.fields = fields;
  }

  /**
   * Reads the CSV file and returns an enumerator over the converted rows.
   *
   * @throws UncheckedIOException if the file cannot be opened or read
   */
  @Override
  public Enumerator<Object[]> enumerator() {
    return Linq4j.enumerator(inferResult());
  }

  /**
   * Reads every CSV record into an {@code Object[]} with one slot per declared field.
   *
   * <p>I/O failures are rethrown rather than reported as an empty result, so a missing
   * or half-read file cannot be mistaken for "no rows".
   */
  private List<Object[]> inferResult() {
    try (CSVReader csvReader = new CSVReader(new FileReader(csvResultPath))) {
      List<Object[]> result = new ArrayList<>();
      String[] line;
      while ((line = csvReader.readNext()) != null) {
        Object[] row = new Object[fields.size()];
        int i = 0;
        for (SqlTypeName fieldType : fields.values()) {
          // A record shorter than the field list yields null for the missing cells
          // instead of an ArrayIndexOutOfBoundsException.
          String cell = i < line.length ? line[i] : null;
          row[i] = convert(fieldType, cell);
          i++;
        }
        result.add(row);
      }
      return Collections.unmodifiableList(result);
    } catch (IOException exception) {
      throw new UncheckedIOException(
          "Failed to read inference result from '" + csvResultPath + "'", exception);
    }
  }

  /**
   * Converts one CSV cell to the Java value for {@code fieldType}.
   *
   * <p>A missing cell ({@code null}) converts to {@code null} for every type. An empty
   * cell converts to {@code null} for boolean and numeric types, and stays an empty
   * string for all other types.
   *
   * @throws NumberFormatException if a non-empty numeric cell cannot be parsed
   */
  private @Nullable Object convert(SqlTypeName fieldType, @Nullable String string) {
    if (string == null) {
      return null;
    }
    switch (fieldType) {
      case BOOLEAN:
      case TINYINT:
      case SMALLINT:
      case INTEGER:
      case BIGINT:
      case FLOAT:
      case DOUBLE:
      case DECIMAL:
        return string.isEmpty() ? null : parseNonEmpty(fieldType, string);
      case DATE:
      case TIME:
      case TIMESTAMP:
      case VARCHAR:
      default:
        return string;
    }
  }

  /** Parses a non-empty cell of a boolean or numeric type; other types pass through. */
  private static Object parseNonEmpty(SqlTypeName fieldType, String text) {
    switch (fieldType) {
      case BOOLEAN:
        return Boolean.parseBoolean(text);
      case TINYINT:
        return Byte.parseByte(text);
      case SMALLINT:
        return Short.parseShort(text);
      case INTEGER:
        // Parsed as a double first so that values written as "3.0" are accepted.
        return (int) Double.parseDouble(text);
      case BIGINT:
        return parseLong(text);
      case FLOAT:
        return Float.parseFloat(text);
      case DOUBLE:
      case DECIMAL:
        // NOTE(review): DECIMAL is surfaced as a Double; confirm that consumers do not
        // expect a BigDecimal for this type.
        return Double.parseDouble(text);
      default:
        return text;
    }
  }

  /**
   * Parses a BIGINT cell. Plain integer text is parsed exactly; anything else (for
   * example "3.0" or "1e3") goes through {@code double} and is truncated. Going through
   * {@code double} unconditionally would lose precision above 2^53.
   */
  private static long parseLong(String text) {
    try {
      return Long.parseLong(text);
    } catch (NumberFormatException ignored) {
      // Not plain integer syntax; fall back to the floating-point form.
      return (long) Double.parseDouble(text);
    }
  }

}
Loading