Skip to content

Commit 4c355aa

Browse files
committed
Feat: TRAIN MODEL on join table
1 parent 87aa993 commit 4c355aa

File tree

7 files changed

+292
-41
lines changed

7 files changed

+292
-41
lines changed

traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ void importSynopsis(String synopsisName, boolean isExternal, JSONObject exportMe
142142

143143
void deleteTasks(Integer cnt) throws CatalogException;
144144

145+
/* Join */
146+
public MTable createJoinTable(List<String> schemaNames, List<String> tableNames,
147+
List<List<String>> columnNames, List<RelDataType> dataTypes,
148+
String joinQuery) throws CatalogException;
149+
150+
/**
 * Removes the catalog entry of a join table.
 *
 * @param schemaName    schema the join table is registered under
 * @param joinTableName name of the join table to remove
 * @throws CatalogException if the join table cannot be removed
 */
void dropJoinTable(String schemaName, String joinTableName) throws CatalogException;
151+
145152
/* Common */
146153
void close();
147154
}

traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Comparator;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.UUID;
2425
import javax.jdo.PersistenceManager;
2526
import javax.jdo.Query;
2627
import javax.jdo.Transaction;
@@ -30,6 +31,7 @@
3031
import org.json.simple.JSONArray;
3132
import org.json.simple.JSONObject;
3233
import traindb.catalog.pm.MColumn;
34+
import traindb.catalog.pm.MJoin;
3335
import traindb.catalog.pm.MModel;
3436
import traindb.catalog.pm.MModeltype;
3537
import traindb.catalog.pm.MQueryLog;
@@ -149,12 +151,89 @@ private MTable addTable(String schemaName, String tableName, List<String> column
149151
return mTable;
150152
}
151153

154+
@Override
155+
public MTable createJoinTable(List<String> schemaNames, List<String> tableNames,
156+
List<List<String>> columnNames, List<RelDataType> dataTypes,
157+
String joinQuery) throws CatalogException {
158+
List<Long> srcTableIds = new ArrayList<>();
159+
for (int i = 0; i < tableNames.size(); i++) {
160+
MTable table = addTable(schemaNames.get(i), tableNames.get(i), columnNames.get(i),
161+
dataTypes.get(i), "TABLE");
162+
srcTableIds.add(table.getId());
163+
}
164+
165+
UUID joinTableId = UUID.randomUUID();
166+
String joinTableName = "__JOIN_" + joinTableId;
167+
try {
168+
// use first table's schema as join table's schema
169+
MTable joinTable = new MTable(joinTableName, "JOIN", getSchema(schemaNames.get(0)));
170+
pm.makePersistent(joinTable);
171+
MTableExt joinTableExt = new MTableExt(joinTableName, "JOIN", joinQuery, joinTable);
172+
pm.makePersistent(joinTableExt);
173+
174+
for (int i = 0; i < tableNames.size(); i++) {
175+
List<String> colnames = columnNames.get(i);
176+
RelDataType relDataType = dataTypes.get(i);
177+
for (String colname : colnames) {
178+
RelDataTypeField field = relDataType.getField(colname, true, false);
179+
MColumn mColumn = new MColumn(field.getName(),
180+
field.getType().getSqlTypeName().getJdbcOrdinal(),
181+
field.getType().getPrecision(),
182+
field.getType().getScale(),
183+
field.getType().isNullable(),
184+
joinTable);
185+
186+
pm.makePersistent(mColumn);
187+
}
188+
}
189+
190+
for (int i = 0; i < tableNames.size(); i++) {
191+
MJoin join = new MJoin(joinTable.getId(), srcTableIds.get(i), columnNames.get(i));
192+
pm.makePersistent(join);
193+
}
194+
195+
return joinTable;
196+
} catch (RuntimeException e) {
197+
throw new CatalogException("failed to create join table", e);
198+
}
199+
}
200+
201+
@Override
202+
public void dropJoinTable(String schemaName, String joinTableName) throws CatalogException {
203+
Transaction tx = pm.currentTransaction();
204+
try {
205+
tx.begin();
206+
207+
MTable joinTable = getTable(schemaName, joinTableName);
208+
if (joinTable == null) {
209+
return;
210+
}
211+
212+
Query query = pm.newQuery(MJoin.class);
213+
setFilterPatterns(query, ImmutableMap.of("join_table_id", joinTable.getId()));
214+
Collection<MJoin> mJoins = (List<MJoin>) query.execute();
215+
pm.deletePersistentAll(mJoins);
216+
pm.deletePersistent(joinTable);
217+
218+
tx.commit();
219+
} catch (RuntimeException e) {
220+
throw new CatalogException("failed to drop join table '" + joinTableName + "'", e);
221+
} finally {
222+
if (tx.isActive()) {
223+
tx.rollback();
224+
}
225+
}
226+
}
227+
152228
@Override
153229
public MModel trainModel(
154230
String modeltypeName, String modelName, String schemaName, String tableName,
155231
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
156232
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
157-
MTable mTable = addTable(schemaName, tableName, columnNames, dataType, "TABLE");
233+
MTable mTable = getTable(schemaName, tableName);
234+
if (mTable == null) {
235+
mTable = addTable(schemaName, tableName, columnNames, dataType, "TABLE");
236+
}
158237
try {
159238
MModeltype mModeltype = getModeltype(modeltypeName);
160239
MModel mModel = new MModel(
@@ -194,6 +273,7 @@ public void dropModel(String name) throws CatalogException {
194273
pm.deletePersistent(mTable);
195274
tx.commit();
196275
}
276+
dropJoinTable(baseSchema, baseTable);
197277

198278
Collection<MModel> baseSchemaModels = getModels(ImmutableMap.of("schema_name", baseSchema));
199279
if (baseSchemaModels == null || baseSchemaModels.size() == 0) {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.catalog.pm;
16+
17+
import java.util.List;
18+
import javax.jdo.annotations.PersistenceCapable;
19+
import javax.jdo.annotations.Persistent;
20+
21+
@PersistenceCapable
22+
public final class MJoin {
23+
24+
@Persistent
25+
private long join_table_id;
26+
27+
@Persistent
28+
private long src_table_id;
29+
30+
@Persistent
31+
private List<String> columns;
32+
33+
public MJoin(long joinTableId, long srcTableId, List<String> columns) {
34+
this.join_table_id = joinTableId;
35+
this.src_table_id = srcTableId;
36+
this.columns = columns;
37+
}
38+
39+
public List<String> getColumnNames() {
40+
return columns;
41+
}
42+
43+
}

traindb-catalog/src/main/java/traindb/catalog/pm/MTable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ public MTable(String name, String type, MSchema schema) {
5656
this.schema = schema;
5757
}
5858

59+
public long getId() {
60+
return id;
61+
}
62+
5963
public String getTableName() {
6064
return table_name;
6165
}

0 commit comments

Comments
 (0)