19
19
import java .util .ArrayList ;
20
20
import java .util .Collection ;
21
21
import java .util .Comparator ;
22
+ import java .util .HashSet ;
23
+ import java .util .Iterator ;
22
24
import java .util .List ;
23
25
import java .util .Map ;
26
+ import java .util .Set ;
27
+ import java .util .UUID ;
28
+ import java .util .stream .Collectors ;
24
29
import javax .jdo .PersistenceManager ;
25
30
import javax .jdo .Query ;
26
31
import javax .jdo .Transaction ;
30
35
import org .json .simple .JSONArray ;
31
36
import org .json .simple .JSONObject ;
32
37
import traindb .catalog .pm .MColumn ;
38
+ import traindb .catalog .pm .MJoin ;
33
39
import traindb .catalog .pm .MModel ;
34
40
import traindb .catalog .pm .MModeltype ;
35
41
import traindb .catalog .pm .MQueryLog ;
@@ -115,11 +121,8 @@ public void dropModeltype(String name) throws CatalogException {
115
121
}
116
122
}
117
123
118
- @ Override
119
- public MModel trainModel (
120
- String modeltypeName , String modelName , String schemaName , String tableName ,
121
- List <String > columnNames , RelDataType dataType , @ Nullable Long baseTableRows ,
122
- @ Nullable Long trainedRows , @ Nullable String options ) throws CatalogException {
124
+ private MTable addTable (String schemaName , String tableName , List <String > columnNames ,
125
+ RelDataType relDataType , String tableType ) throws CatalogException {
123
126
MTable mTable ;
124
127
try {
125
128
MSchema mSchema = getSchema (schemaName );
@@ -130,11 +133,11 @@ public MModel trainModel(
130
133
131
134
mTable = getTable (schemaName , tableName );
132
135
if (mTable == null ) {
133
- mTable = new MTable (tableName , "TABLE" , mSchema );
136
+ mTable = new MTable (tableName , tableType , mSchema );
134
137
pm .makePersistent (mTable );
135
138
136
- List <RelDataTypeField > fields = dataType .getFieldList ();
137
- for (int i = 0 ; i < dataType .getFieldCount (); i ++) {
139
+ List <RelDataTypeField > fields = relDataType .getFieldList ();
140
+ for (int i = 0 ; i < relDataType .getFieldCount (); i ++) {
138
141
RelDataTypeField field = fields .get (i );
139
142
MColumn mColumn = new MColumn (field .getName (),
140
143
field .getType ().getSqlTypeName ().getJdbcOrdinal (),
@@ -146,7 +149,153 @@ public MModel trainModel(
146
149
pm .makePersistent (mColumn );
147
150
}
148
151
}
152
+ } catch (RuntimeException e ) {
153
+ throw new CatalogException ("failed to add table '" + schemaName + "." + tableName + "'" , e );
154
+ }
155
+ return mTable ;
156
+ }
157
+
158
+ @ Override
159
+ public MTable createJoinTable (List <String > schemaNames , List <String > tableNames ,
160
+ List <List <String >> columnNames , List <RelDataType > dataTypes ,
161
+ String joinCondition ) throws CatalogException {
162
+ List <Long > srcTableIds = new ArrayList <>();
163
+ for (int i = 0 ; i < tableNames .size (); i ++) {
164
+ MTable table = addTable (schemaNames .get (i ), tableNames .get (i ), columnNames .get (i ),
165
+ dataTypes .get (i ), "TABLE" );
166
+ srcTableIds .add (table .getId ());
167
+ }
168
+
169
+ UUID joinTableId = UUID .randomUUID ();
170
+ String joinTableName = "__JOIN_" + joinTableId ;
171
+ try {
172
+ // use first table's schema as join table's schema
173
+ MTable joinTable = new MTable (joinTableName , "JOIN" , getSchema (schemaNames .get (0 )));
174
+ pm .makePersistent (joinTable );
175
+ MTableExt joinTableExt = new MTableExt (joinTableName , "JOIN" , joinCondition , joinTable );
176
+ pm .makePersistent (joinTableExt );
177
+
178
+ for (int i = 0 ; i < tableNames .size (); i ++) {
179
+ List <String > colnames = columnNames .get (i );
180
+ RelDataType relDataType = dataTypes .get (i );
181
+ for (String colname : colnames ) {
182
+ RelDataTypeField field = relDataType .getField (colname , true , false );
183
+ MColumn mColumn = new MColumn (field .getName (),
184
+ field .getType ().getSqlTypeName ().getJdbcOrdinal (),
185
+ field .getType ().getPrecision (),
186
+ field .getType ().getScale (),
187
+ field .getType ().isNullable (),
188
+ joinTable );
189
+
190
+ pm .makePersistent (mColumn );
191
+ }
192
+ }
193
+
194
+ for (int i = 0 ; i < tableNames .size (); i ++) {
195
+ MJoin join = new MJoin (joinTable .getId (), srcTableIds .get (i ), columnNames .get (i ));
196
+ pm .makePersistent (join );
197
+ }
198
+
199
+ return joinTable ;
200
+ } catch (RuntimeException e ) {
201
+ throw new CatalogException ("failed to create join table" , e );
202
+ }
203
+ }
149
204
205
+ @ Override
206
+ public void dropJoinTable (String schemaName , String joinTableName ) throws CatalogException {
207
+ Transaction tx = pm .currentTransaction ();
208
+ try {
209
+ tx .begin ();
210
+
211
+ MTable joinTable = getTable (schemaName , joinTableName );
212
+ if (joinTable == null ) {
213
+ return ;
214
+ }
215
+
216
+ Query query = pm .newQuery (MJoin .class );
217
+ setFilterPatterns (query , ImmutableMap .of ("join_table_id" , joinTable .getId ()));
218
+ Collection <MJoin > mJoins = (List <MJoin >) query .execute ();
219
+ pm .deletePersistentAll (mJoins );
220
+ pm .deletePersistent (joinTable );
221
+
222
+ tx .commit ();
223
+ } catch (RuntimeException e ) {
224
+ throw new CatalogException ("failed to drop join table '" + joinTableName + "'" , e );
225
+ } finally {
226
+ if (tx .isActive ()) {
227
+ tx .rollback ();
228
+ }
229
+ }
230
+ }
231
+
232
+ @ Override
233
+ public Collection <MSynopsis > getJoinSynopses (
234
+ List <Long > baseTableIds , Map <Long , List <String >> columnNames , String joinCondition )
235
+ throws CatalogException {
236
+ try {
237
+ List <Long > joinTableIds = null ;
238
+ for (Long tid : baseTableIds ) {
239
+ Query query = pm .newQuery (MJoin .class );
240
+ setFilterPatterns (query , ImmutableMap .of ("src_table_id" , tid ));
241
+ List <MJoin > mJoins = (List <MJoin >) query .execute ();
242
+ List <Long > ids = mJoins .stream ()
243
+ .filter (obj -> obj .containsColumnNames (columnNames .get (tid )))
244
+ .map (MJoin ::getJoinTableId ).collect (Collectors .toList ());
245
+ if (joinTableIds == null ) {
246
+ joinTableIds = ids ;
247
+ } else {
248
+ joinTableIds .retainAll (ids );
249
+ }
250
+ }
251
+ if (joinTableIds == null ) {
252
+ return null ;
253
+ }
254
+
255
+ List <MSynopsis > joinSynopses = new ArrayList <>();
256
+ for (Long joinTableId : joinTableIds ) {
257
+ Collection <MTable > joinTable = getTables (ImmutableMap .of ("id" , joinTableId ));
258
+ for (MTable jt : joinTable ) {
259
+ Collection <MTableExt > tableExts = jt .getTableExts ();
260
+ if (tableExts == null || tableExts .isEmpty ()) {
261
+ continue ;
262
+ }
263
+ for (MTableExt tableExt : tableExts ) {
264
+ if (tableExt .getExternalTableUri ().contains (joinCondition )) {
265
+ Collection <MSynopsis > synopses =
266
+ getAllSynopses (jt .getSchema ().getSchemaName (), jt .getTableName ());
267
+ joinSynopses .addAll (synopses );
268
+ break ;
269
+ }
270
+ }
271
+ }
272
+ }
273
+ return joinSynopses ;
274
+ } catch (RuntimeException e ) {
275
+ throw new CatalogException ("failed to get synopses" , e );
276
+ }
277
+
278
+ }
279
+
280
+ @ Override
281
+ public MModel trainModel (
282
+ String modeltypeName , String modelName , String schemaName , String tableName ,
283
+ List <String > columnNames , RelDataType dataType , @ Nullable Long baseTableRows ,
284
+ @ Nullable Long trainedRows , @ Nullable String options ) throws CatalogException {
285
+ MTable mTable = getTable (schemaName , tableName );
286
+ if (mTable == null ) {
287
+ mTable = addTable (schemaName , tableName , columnNames , dataType , "TABLE" );
288
+ }
289
+ try {
290
+
291
+ Set <String > columnNameSet = new HashSet <>();
292
+ Iterator <String > iter = columnNames .iterator ();
293
+ while (iter .hasNext ()) {
294
+ String colname = iter .next ();
295
+ if (!columnNameSet .add (colname )) {
296
+ iter .remove ();
297
+ }
298
+ }
150
299
MModeltype mModeltype = getModeltype (modeltypeName );
151
300
MModel mModel = new MModel (
152
301
mModeltype , modelName , schemaName , tableName , columnNames ,
@@ -185,6 +334,7 @@ public void dropModel(String name) throws CatalogException {
185
334
pm .deletePersistent (mTable );
186
335
tx .commit ();
187
336
}
337
+ dropJoinTable (baseSchema , baseTable );
188
338
189
339
Collection <MModel > baseSchemaModels = getModels (ImmutableMap .of ("schema_name" , baseSchema ));
190
340
if (baseSchemaModels == null || baseSchemaModels .size () == 0 ) {
0 commit comments