Skip to content

Commit 0066c81

Browse files
taewhiChoonseoPark
authored andcommitted
Test : update test cases for basic spatial analysis functions (sort results by id)
2 parents 5a0f797 + ce4027b commit 0066c81

File tree

16 files changed

+1006
-490
lines changed

16 files changed

+1006
-490
lines changed

traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import traindb.catalog.pm.MSynopsis;
2727
import traindb.catalog.pm.MTable;
2828
import traindb.catalog.pm.MTask;
29+
import traindb.catalog.pm.MTrainingStatus;
2930

3031
public interface CatalogContext {
3132

@@ -62,6 +63,11 @@ Collection<MModel> getInferenceModels(String baseSchema, String baseTable)
6263

6364
MModel getModel(String name);
6465

66+
Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
67+
throws CatalogException;
68+
69+
void updateTrainingStatus(String modelName, String status) throws CatalogException;
70+
6571
/* Synopsis */
6672
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
6773
throws CatalogException;

traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
package traindb.catalog;
1616

1717
import com.google.common.collect.ImmutableMap;
18+
import java.sql.Timestamp;
19+
import java.util.ArrayList;
1820
import java.util.Collection;
21+
import java.util.Comparator;
1922
import java.util.List;
2023
import java.util.Map;
2124
import javax.jdo.PersistenceManager;
@@ -32,6 +35,7 @@
3235
import traindb.catalog.pm.MSynopsis;
3336
import traindb.catalog.pm.MTable;
3437
import traindb.catalog.pm.MTask;
38+
import traindb.catalog.pm.MTrainingStatus;
3539
import traindb.common.TrainDBLogger;
3640

3741
public final class JDOCatalogContext implements CatalogContext {
@@ -139,10 +143,18 @@ public MModel trainModel(
139143
}
140144
}
141145

146+
MModeltype mModeltype = getModeltype(modeltypeName);
142147
MModel mModel = new MModel(
143-
getModeltype(modeltypeName), modelName, schemaName, tableName, columnNames,
148+
mModeltype, modelName, schemaName, tableName, columnNames,
144149
baseTableRows, trainedRows, options == null ? "" : options);
145150
pm.makePersistent(mModel);
151+
152+
if (mModeltype.getLocation().equals("REMOTE")) {
153+
MTrainingStatus mTrainingStatus = new MTrainingStatus(modelName, "TRAINING",
154+
new Timestamp(System.currentTimeMillis()), mModel);
155+
pm.makePersistent(mTrainingStatus);
156+
}
157+
146158
return mModel;
147159
} catch (RuntimeException e) {
148160
throw new CatalogException("failed to train model '" + modelName + "'", e);
@@ -178,6 +190,13 @@ public void dropModel(String name) throws CatalogException {
178190
tx.commit();
179191
}
180192

193+
Collection<MTrainingStatus> trainingStatus =
194+
getTrainingStatus(ImmutableMap.of("model_name", name));
195+
if (trainingStatus != null && trainingStatus.size() > 0) {
196+
tx.begin();
197+
pm.deletePersistentAll(trainingStatus);
198+
tx.commit();
199+
}
181200
} catch (RuntimeException e) {
182201
throw new CatalogException("failed to drop model '" + name + "'", e);
183202
} finally {
@@ -228,6 +247,33 @@ public boolean modelExists(String name) {
228247
return null;
229248
}
230249

250+
@Override
251+
public Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
252+
throws CatalogException {
253+
try {
254+
Query query = pm.newQuery(MTrainingStatus.class);
255+
setFilterPatterns(query, filterPatterns);
256+
return (List<MTrainingStatus>) query.execute();
257+
} catch (RuntimeException e) {
258+
throw new CatalogException("failed to get training status", e);
259+
}
260+
}
261+
262+
@Override
263+
public void updateTrainingStatus(String modelName, String status) throws CatalogException {
264+
try {
265+
Query query = pm.newQuery(MTrainingStatus.class);
266+
setFilterPatterns(query, ImmutableMap.of("model_name", modelName));
267+
List<MTrainingStatus> trainingStatus = (List<MTrainingStatus>) query.execute();
268+
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
269+
MTrainingStatus latestStatus = trainingStatus.stream().max(comparator).get();
270+
latestStatus.setTrainingStatus(status);
271+
pm.makePersistent(latestStatus);
272+
} catch (RuntimeException e) {
273+
throw new CatalogException("failed to get training status", e);
274+
}
275+
}
276+
231277
@Override
232278
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
233279
@Nullable Double ratio) throws CatalogException {

traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
package traindb.catalog.pm;
1616

17+
import java.util.Collection;
18+
import java.util.Comparator;
1719
import java.util.List;
1820
import javax.jdo.annotations.Column;
1921
import javax.jdo.annotations.IdGeneratorStrategy;
@@ -58,6 +60,9 @@ public final class MModel {
5860
@Persistent
5961
private byte[] model_options;
6062

63+
@Persistent(mappedBy = "model", dependentElement = "true")
64+
private Collection<MTrainingStatus> training_status;
65+
6166
public MModel(
6267
MModeltype modeltype, String modelName, String schemaName, String tableName,
6368
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
@@ -103,4 +108,17 @@ public long getTrainedRows() {
103108
public String getModelOptions() {
104109
return new String(model_options);
105110
}
111+
112+
public Collection<MTrainingStatus> trainingStatus() {
113+
return training_status;
114+
}
115+
116+
public boolean isEnabled() {
117+
if (training_status.isEmpty() || training_status.size() == 0) {
118+
return true;
119+
}
120+
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
121+
MTrainingStatus latestStatus = training_status.stream().max(comparator).get();
122+
return latestStatus.getTrainingStatus().equals("FINISHED");
123+
}
106124
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.catalog.pm;
16+
17+
import java.sql.Timestamp;
18+
import javax.jdo.annotations.Column;
19+
import javax.jdo.annotations.IdGeneratorStrategy;
20+
import javax.jdo.annotations.Index;
21+
import javax.jdo.annotations.PersistenceCapable;
22+
import javax.jdo.annotations.Persistent;
23+
import javax.jdo.annotations.PrimaryKey;
24+
import traindb.catalog.CatalogConstants;
25+
26+
@PersistenceCapable
27+
@Index(name="TRAINING_STATUS_IDX", members={"model_name", "start_time"})
28+
public final class MTrainingStatus {
29+
@PrimaryKey
30+
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
31+
private long id;
32+
33+
@Persistent
34+
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
35+
private String model_name;
36+
37+
@Persistent
38+
private Timestamp start_time;
39+
40+
@Persistent
41+
@Column(length = 9)
42+
// Status: TRAINING, FINISHED
43+
private String training_status;
44+
45+
@Persistent(dependent = "false")
46+
private MModel model;
47+
48+
public MTrainingStatus(String modelName, String status, Timestamp startTime, MModel model) {
49+
this.model_name = modelName;
50+
this.training_status = status;
51+
this.start_time = startTime;
52+
this.model = model;
53+
}
54+
55+
public String getModelName() {
56+
return model_name;
57+
}
58+
59+
public Timestamp getStartTime() {
60+
return start_time;
61+
}
62+
63+
public String getTrainingStatus() {
64+
return training_status;
65+
}
66+
67+
public MModel getModel() {
68+
return model;
69+
}
70+
71+
public void setTrainingStatus(String status) {
72+
this.training_status = status;
73+
}
74+
}

traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ showTargets
125125
| K_SCHEMAS
126126
| K_TABLES
127127
| K_HYPERPARAMETERS
128+
| K_TRAININGS
128129
| K_QUERYLOGS
129130
| K_TASKS
130131
;
@@ -241,6 +242,7 @@ K_SYNOPSIS : S Y N O P S I S ;
241242
K_TABLES : T A B L E S ;
242243
K_TASKS : T A S K S ;
243244
K_TRAIN : T R A I N ;
245+
K_TRAININGS : T R A I N I N G S;
244246
K_USE : U S E ;
245247
K_WHERE : W H E R E ;
246248

traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,30 @@ public abstract String infer(String aggregateExpression, String groupByColumn,
5353

5454
public abstract String listHyperparameters(String className, String uri) throws Exception;
5555

56+
public boolean checkAvailable(String modelName) throws Exception {
57+
return true;
58+
}
59+
5660
public Path getModelPath() {
5761
return Paths.get(TrainDBConfiguration.getTrainDBPrefixPath(), "models",
5862
modeltypeName, modelName);
5963
}
6064

65+
public static AbstractTrainDBModelRunner createModelRunner(
66+
TrainDBConnectionImpl conn, CatalogContext catalogContext, TrainDBConfiguration config,
67+
String modeltypeName, String modelName, String location) {
68+
if (location.equals("REMOTE")) {
69+
return new TrainDBFastApiModelRunner(conn, catalogContext, modeltypeName, modelName);
70+
}
71+
// location.equals("LOCAL")
72+
if (config.getModelRunner().equals("py4j")) {
73+
return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName);
74+
}
75+
76+
return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName);
77+
}
78+
79+
6180
protected String buildSelectTrainingDataQuery(String schemaName, String tableName,
6281
List<String> columnNames) {
6382
StringBuilder sb = new StringBuilder();

0 commit comments

Comments
 (0)