Skip to content

Commit da6bdc6

Browse files
authored
Merge pull request #57 from traindb-project/dev/issue-56
Feat: TRAIN MODEL on columns from multiple tables
2 parents ff456b7 + dee5be6 commit da6bdc6

15 files changed

+752
-85
lines changed

traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ void importSynopsis(String synopsisName, boolean isExternal, JSONObject exportMe
142142

143143
void deleteTasks(Integer cnt) throws CatalogException;
144144

145+
/* Join */
146+
MTable createJoinTable(List<String> schemaNames, List<String> tableNames,
147+
List<List<String>> columnNames, List<RelDataType> dataTypes,
148+
String joinCondition) throws CatalogException;
149+
150+
void dropJoinTable(String schemaName, String joinTableName) throws CatalogException;
151+
152+
Collection<MSynopsis> getJoinSynopses(
153+
List<Long> baseTableIds, Map<Long, List<String>> columnNames, String joinCondition)
154+
throws CatalogException;
155+
145156
/* Common */
146157
void close();
147158
}

traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@
1919
import java.util.ArrayList;
2020
import java.util.Collection;
2121
import java.util.Comparator;
22+
import java.util.HashSet;
23+
import java.util.Iterator;
2224
import java.util.List;
2325
import java.util.Map;
26+
import java.util.Set;
27+
import java.util.UUID;
28+
import java.util.stream.Collectors;
2429
import javax.jdo.PersistenceManager;
2530
import javax.jdo.Query;
2631
import javax.jdo.Transaction;
@@ -30,6 +35,7 @@
3035
import org.json.simple.JSONArray;
3136
import org.json.simple.JSONObject;
3237
import traindb.catalog.pm.MColumn;
38+
import traindb.catalog.pm.MJoin;
3339
import traindb.catalog.pm.MModel;
3440
import traindb.catalog.pm.MModeltype;
3541
import traindb.catalog.pm.MQueryLog;
@@ -115,11 +121,8 @@ public void dropModeltype(String name) throws CatalogException {
115121
}
116122
}
117123

118-
@Override
119-
public MModel trainModel(
120-
String modeltypeName, String modelName, String schemaName, String tableName,
121-
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
122-
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
124+
private MTable addTable(String schemaName, String tableName, List<String> columnNames,
125+
RelDataType relDataType, String tableType) throws CatalogException {
123126
MTable mTable;
124127
try {
125128
MSchema mSchema = getSchema(schemaName);
@@ -130,11 +133,11 @@ public MModel trainModel(
130133

131134
mTable = getTable(schemaName, tableName);
132135
if (mTable == null) {
133-
mTable = new MTable(tableName, "TABLE", mSchema);
136+
mTable = new MTable(tableName, tableType, mSchema);
134137
pm.makePersistent(mTable);
135138

136-
List<RelDataTypeField> fields = dataType.getFieldList();
137-
for (int i = 0; i < dataType.getFieldCount(); i++) {
139+
List<RelDataTypeField> fields = relDataType.getFieldList();
140+
for (int i = 0; i < relDataType.getFieldCount(); i++) {
138141
RelDataTypeField field = fields.get(i);
139142
MColumn mColumn = new MColumn(field.getName(),
140143
field.getType().getSqlTypeName().getJdbcOrdinal(),
@@ -146,7 +149,153 @@ public MModel trainModel(
146149
pm.makePersistent(mColumn);
147150
}
148151
}
152+
} catch (RuntimeException e) {
153+
throw new CatalogException("failed to add table '" + schemaName + "." + tableName + "'", e);
154+
}
155+
return mTable;
156+
}
157+
158+
@Override
159+
public MTable createJoinTable(List<String> schemaNames, List<String> tableNames,
160+
List<List<String>> columnNames, List<RelDataType> dataTypes,
161+
String joinCondition) throws CatalogException {
162+
List<Long> srcTableIds = new ArrayList<>();
163+
for (int i = 0; i < tableNames.size(); i++) {
164+
MTable table = addTable(schemaNames.get(i), tableNames.get(i), columnNames.get(i),
165+
dataTypes.get(i), "TABLE");
166+
srcTableIds.add(table.getId());
167+
}
168+
169+
UUID joinTableId = UUID.randomUUID();
170+
String joinTableName = "__JOIN_" + joinTableId;
171+
try {
172+
// use first table's schema as join table's schema
173+
MTable joinTable = new MTable(joinTableName, "JOIN", getSchema(schemaNames.get(0)));
174+
pm.makePersistent(joinTable);
175+
MTableExt joinTableExt = new MTableExt(joinTableName, "JOIN", joinCondition, joinTable);
176+
pm.makePersistent(joinTableExt);
177+
178+
for (int i = 0; i < tableNames.size(); i++) {
179+
List<String> colnames = columnNames.get(i);
180+
RelDataType relDataType = dataTypes.get(i);
181+
for (String colname : colnames) {
182+
RelDataTypeField field = relDataType.getField(colname, true, false);
183+
MColumn mColumn = new MColumn(field.getName(),
184+
field.getType().getSqlTypeName().getJdbcOrdinal(),
185+
field.getType().getPrecision(),
186+
field.getType().getScale(),
187+
field.getType().isNullable(),
188+
joinTable);
189+
190+
pm.makePersistent(mColumn);
191+
}
192+
}
193+
194+
for (int i = 0; i < tableNames.size(); i++) {
195+
MJoin join = new MJoin(joinTable.getId(), srcTableIds.get(i), columnNames.get(i));
196+
pm.makePersistent(join);
197+
}
198+
199+
return joinTable;
200+
} catch (RuntimeException e) {
201+
throw new CatalogException("failed to create join table", e);
202+
}
203+
}
149204

205+
@Override
206+
public void dropJoinTable(String schemaName, String joinTableName) throws CatalogException {
207+
Transaction tx = pm.currentTransaction();
208+
try {
209+
tx.begin();
210+
211+
MTable joinTable = getTable(schemaName, joinTableName);
212+
if (joinTable == null) {
213+
return;
214+
}
215+
216+
Query query = pm.newQuery(MJoin.class);
217+
setFilterPatterns(query, ImmutableMap.of("join_table_id", joinTable.getId()));
218+
Collection<MJoin> mJoins = (List<MJoin>) query.execute();
219+
pm.deletePersistentAll(mJoins);
220+
pm.deletePersistent(joinTable);
221+
222+
tx.commit();
223+
} catch (RuntimeException e) {
224+
throw new CatalogException("failed to drop join table '" + joinTableName + "'", e);
225+
} finally {
226+
if (tx.isActive()) {
227+
tx.rollback();
228+
}
229+
}
230+
}
231+
232+
@Override
233+
public Collection<MSynopsis> getJoinSynopses(
234+
List<Long> baseTableIds, Map<Long, List<String>> columnNames, String joinCondition)
235+
throws CatalogException {
236+
try {
237+
List<Long> joinTableIds = null;
238+
for (Long tid : baseTableIds) {
239+
Query query = pm.newQuery(MJoin.class);
240+
setFilterPatterns(query, ImmutableMap.of("src_table_id", tid));
241+
List<MJoin> mJoins = (List<MJoin>) query.execute();
242+
List<Long> ids = mJoins.stream()
243+
.filter(obj -> obj.containsColumnNames(columnNames.get(tid)))
244+
.map(MJoin::getJoinTableId).collect(Collectors.toList());
245+
if (joinTableIds == null) {
246+
joinTableIds = ids;
247+
} else {
248+
joinTableIds.retainAll(ids);
249+
}
250+
}
251+
if (joinTableIds == null) {
252+
return null;
253+
}
254+
255+
List<MSynopsis> joinSynopses = new ArrayList<>();
256+
for (Long joinTableId : joinTableIds) {
257+
Collection<MTable> joinTable = getTables(ImmutableMap.of("id", joinTableId));
258+
for (MTable jt : joinTable) {
259+
Collection<MTableExt> tableExts = jt.getTableExts();
260+
if (tableExts == null || tableExts.isEmpty()) {
261+
continue;
262+
}
263+
for (MTableExt tableExt : tableExts) {
264+
if (tableExt.getExternalTableUri().contains(joinCondition)) {
265+
Collection<MSynopsis> synopses =
266+
getAllSynopses(jt.getSchema().getSchemaName(), jt.getTableName());
267+
joinSynopses.addAll(synopses);
268+
break;
269+
}
270+
}
271+
}
272+
}
273+
return joinSynopses;
274+
} catch (RuntimeException e) {
275+
throw new CatalogException("failed to get synopses", e);
276+
}
277+
278+
}
279+
280+
@Override
281+
public MModel trainModel(
282+
String modeltypeName, String modelName, String schemaName, String tableName,
283+
List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
284+
@Nullable Long trainedRows, @Nullable String options) throws CatalogException {
285+
MTable mTable = getTable(schemaName, tableName);
286+
if (mTable == null) {
287+
mTable = addTable(schemaName, tableName, columnNames, dataType, "TABLE");
288+
}
289+
try {
290+
291+
Set<String> columnNameSet = new HashSet<>();
292+
Iterator<String> iter = columnNames.iterator();
293+
while (iter.hasNext()) {
294+
String colname = iter.next();
295+
if (!columnNameSet.add(colname)) {
296+
iter.remove();
297+
}
298+
}
150299
MModeltype mModeltype = getModeltype(modeltypeName);
151300
MModel mModel = new MModel(
152301
mModeltype, modelName, schemaName, tableName, columnNames,
@@ -185,6 +334,7 @@ public void dropModel(String name) throws CatalogException {
185334
pm.deletePersistent(mTable);
186335
tx.commit();
187336
}
337+
dropJoinTable(baseSchema, baseTable);
188338

189339
Collection<MModel> baseSchemaModels = getModels(ImmutableMap.of("schema_name", baseSchema));
190340
if (baseSchemaModels == null || baseSchemaModels.size() == 0) {
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package traindb.catalog.pm;
16+
17+
import java.util.List;
18+
import javax.jdo.annotations.PersistenceCapable;
19+
import javax.jdo.annotations.Persistent;
20+
21+
@PersistenceCapable
22+
public final class MJoin {
23+
24+
@Persistent
25+
private long join_table_id;
26+
27+
@Persistent
28+
private long src_table_id;
29+
30+
@Persistent
31+
private List<String> columns;
32+
33+
public MJoin(long joinTableId, long srcTableId, List<String> columns) {
34+
this.join_table_id = joinTableId;
35+
this.src_table_id = srcTableId;
36+
this.columns = columns;
37+
}
38+
39+
public long getJoinTableId() {
40+
return join_table_id;
41+
}
42+
43+
public List<String> getColumnNames() {
44+
return columns;
45+
}
46+
47+
public boolean containsColumnNames(List<String> columnNames) {
48+
return this.columns.containsAll(columnNames);
49+
}
50+
51+
}

traindb-catalog/src/main/java/traindb/catalog/pm/MTable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ public MTable(String name, String type, MSchema schema) {
5656
this.schema = schema;
5757
}
5858

59+
public long getId() {
60+
return id;
61+
}
62+
5963
public String getTableName() {
6064
return table_name;
6165
}

0 commit comments

Comments
 (0)