|
21 | 21 | import java.util.Comparator;
|
22 | 22 | import java.util.List;
|
23 | 23 | import java.util.Map;
|
| 24 | +import java.util.UUID; |
24 | 25 | import javax.jdo.PersistenceManager;
|
25 | 26 | import javax.jdo.Query;
|
26 | 27 | import javax.jdo.Transaction;
|
|
30 | 31 | import org.json.simple.JSONArray;
|
31 | 32 | import org.json.simple.JSONObject;
|
32 | 33 | import traindb.catalog.pm.MColumn;
|
| 34 | +import traindb.catalog.pm.MJoin; |
33 | 35 | import traindb.catalog.pm.MModel;
|
34 | 36 | import traindb.catalog.pm.MModeltype;
|
35 | 37 | import traindb.catalog.pm.MQueryLog;
|
@@ -149,12 +151,89 @@ private MTable addTable(String schemaName, String tableName, List<String> column
|
149 | 151 | return mTable;
|
150 | 152 | }
|
151 | 153 |
|
| 154 | + @Override |
| 155 | + public MTable createJoinTable(List<String> schemaNames, List<String> tableNames, |
| 156 | + List<List<String>> columnNames, List<RelDataType> dataTypes, |
| 157 | + String joinQuery) throws CatalogException { |
| 158 | + List<Long> srcTableIds = new ArrayList<>(); |
| 159 | + for (int i = 0; i < tableNames.size(); i++) { |
| 160 | + MTable table = addTable(schemaNames.get(i), tableNames.get(i), columnNames.get(i), |
| 161 | + dataTypes.get(i), "TABLE"); |
| 162 | + srcTableIds.add(table.getId()); |
| 163 | + } |
| 164 | + |
| 165 | + UUID joinTableId = UUID.randomUUID(); |
| 166 | + String joinTableName = "__JOIN_" + joinTableId; |
| 167 | + try { |
| 168 | + // use first table's schema as join table's schema |
| 169 | + MTable joinTable = new MTable(joinTableName, "JOIN", getSchema(schemaNames.get(0))); |
| 170 | + pm.makePersistent(joinTable); |
| 171 | + MTableExt joinTableExt = new MTableExt(joinTableName, "JOIN", joinQuery, joinTable); |
| 172 | + pm.makePersistent(joinTableExt); |
| 173 | + |
| 174 | + for (int i = 0; i < tableNames.size(); i++) { |
| 175 | + List<String> colnames = columnNames.get(i); |
| 176 | + RelDataType relDataType = dataTypes.get(i); |
| 177 | + for (String colname : colnames) { |
| 178 | + RelDataTypeField field = relDataType.getField(colname, true, false); |
| 179 | + MColumn mColumn = new MColumn(field.getName(), |
| 180 | + field.getType().getSqlTypeName().getJdbcOrdinal(), |
| 181 | + field.getType().getPrecision(), |
| 182 | + field.getType().getScale(), |
| 183 | + field.getType().isNullable(), |
| 184 | + joinTable); |
| 185 | + |
| 186 | + pm.makePersistent(mColumn); |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + for (int i = 0; i < tableNames.size(); i++) { |
| 191 | + MJoin join = new MJoin(joinTable.getId(), srcTableIds.get(i), columnNames.get(i)); |
| 192 | + pm.makePersistent(join); |
| 193 | + } |
| 194 | + |
| 195 | + return joinTable; |
| 196 | + } catch (RuntimeException e) { |
| 197 | + throw new CatalogException("failed to create join table", e); |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + @Override |
| 202 | + public void dropJoinTable(String schemaName, String joinTableName) throws CatalogException { |
| 203 | + Transaction tx = pm.currentTransaction(); |
| 204 | + try { |
| 205 | + tx.begin(); |
| 206 | + |
| 207 | + MTable joinTable = getTable(schemaName, joinTableName); |
| 208 | + if (joinTable == null) { |
| 209 | + return; |
| 210 | + } |
| 211 | + |
| 212 | + Query query = pm.newQuery(MJoin.class); |
| 213 | + setFilterPatterns(query, ImmutableMap.of("join_table_id", joinTable.getId())); |
| 214 | + Collection<MJoin> mJoins = (List<MJoin>) query.execute(); |
| 215 | + pm.deletePersistentAll(mJoins); |
| 216 | + pm.deletePersistent(joinTable); |
| 217 | + |
| 218 | + tx.commit(); |
| 219 | + } catch (RuntimeException e) { |
| 220 | + throw new CatalogException("failed to drop join table '" + joinTableName + "'", e); |
| 221 | + } finally { |
| 222 | + if (tx.isActive()) { |
| 223 | + tx.rollback(); |
| 224 | + } |
| 225 | + } |
| 226 | + } |
| 227 | + |
152 | 228 | @Override
|
153 | 229 | public MModel trainModel(
|
154 | 230 | String modeltypeName, String modelName, String schemaName, String tableName,
|
155 | 231 | List<String> columnNames, RelDataType dataType, @Nullable Long baseTableRows,
|
156 | 232 | @Nullable Long trainedRows, @Nullable String options) throws CatalogException {
|
157 |
| - MTable mTable = addTable(schemaName, tableName, columnNames, dataType, "TABLE"); |
| 233 | + MTable mTable = getTable(schemaName, tableName); |
| 234 | + if (mTable == null) { |
| 235 | + mTable = addTable(schemaName, tableName, columnNames, dataType, "TABLE"); |
| 236 | + } |
158 | 237 | try {
|
159 | 238 | MModeltype mModeltype = getModeltype(modeltypeName);
|
160 | 239 | MModel mModel = new MModel(
|
@@ -194,6 +273,7 @@ public void dropModel(String name) throws CatalogException {
|
194 | 273 | pm.deletePersistent(mTable);
|
195 | 274 | tx.commit();
|
196 | 275 | }
|
| 276 | + dropJoinTable(baseSchema, baseTable); |
197 | 277 |
|
198 | 278 | Collection<MModel> baseSchemaModels = getModels(ImmutableMap.of("schema_name", baseSchema));
|
199 | 279 | if (baseSchemaModels == null || baseSchemaModels.size() == 0) {
|
|
0 commit comments