@@ -369,14 +369,26 @@ impl AggregateExec {
369369 new_requirement. extend ( req) ;
370370 new_requirement = collapse_lex_req ( new_requirement) ;
371371
372- let input_order_mode =
373- if indices. len ( ) == groupby_exprs. len ( ) && !indices. is_empty ( ) {
374- InputOrderMode :: Sorted
375- } else if !indices. is_empty ( ) {
376- InputOrderMode :: PartiallySorted ( indices)
377- } else {
378- InputOrderMode :: Linear
379- } ;
372+ // If our aggregation has grouping sets then our base grouping exprs will
373+ // be expanded based on the flags in `group_by.groups` where for each
374+ // group we swap the grouping expr for `null` if the flag is `true`
375+ // That means that each index in `indices` is valid if and only if
376+ // it is not null in every group
377+ let indices: Vec < usize > = indices
378+ . into_iter ( )
379+ . filter ( |idx| group_by. groups . iter ( ) . all ( |group| !group[ * idx] ) )
380+ . collect ( ) ;
381+
382+ let input_order_mode = if indices. len ( ) == groupby_exprs. len ( )
383+ && !indices. is_empty ( )
384+ && group_by. groups . len ( ) == 1
385+ {
386+ InputOrderMode :: Sorted
387+ } else if !indices. is_empty ( ) {
388+ InputOrderMode :: PartiallySorted ( indices)
389+ } else {
390+ InputOrderMode :: Linear
391+ } ;
380392
381393 // construct a map from the input expression to the output expression of the Aggregation group by
382394 let projection_mapping =
@@ -1180,6 +1192,7 @@ mod tests {
11801192 use arrow:: array:: { Float64Array , UInt32Array } ;
11811193 use arrow:: compute:: { concat_batches, SortOptions } ;
11821194 use arrow:: datatypes:: DataType ;
1195+ use arrow_array:: { Float32Array , Int32Array } ;
11831196 use datafusion_common:: {
11841197 assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError ,
11851198 ScalarValue ,
@@ -1195,7 +1208,9 @@ mod tests {
11951208 use datafusion_physical_expr:: expressions:: { lit, OrderSensitiveArrayAgg } ;
11961209 use datafusion_physical_expr:: PhysicalSortExpr ;
11971210
1211+ use crate :: common:: collect;
11981212 use datafusion_physical_expr_common:: aggregate:: create_aggregate_expr;
1213+ use datafusion_physical_expr_common:: expressions:: Literal ;
11991214 use futures:: { FutureExt , Stream } ;
12001215
12011216 // Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -2267,4 +2282,94 @@ mod tests {
22672282 assert_eq ! ( new_agg. schema( ) , aggregate_exec. schema( ) ) ;
22682283 Ok ( ( ) )
22692284 }
2285+
2286+ #[ tokio:: test]
2287+ async fn test_agg_exec_group_by_const ( ) -> Result < ( ) > {
2288+ let schema = Arc :: new ( Schema :: new ( vec ! [
2289+ Field :: new( "a" , DataType :: Float32 , true ) ,
2290+ Field :: new( "b" , DataType :: Float32 , true ) ,
2291+ Field :: new( "const" , DataType :: Int32 , false ) ,
2292+ ] ) ) ;
2293+
2294+ let col_a = col ( "a" , & schema) ?;
2295+ let col_b = col ( "b" , & schema) ?;
2296+ let const_expr = Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 1 ) ) ) ) ;
2297+
2298+ let groups = PhysicalGroupBy :: new (
2299+ vec ! [
2300+ ( col_a, "a" . to_string( ) ) ,
2301+ ( col_b, "b" . to_string( ) ) ,
2302+ ( const_expr, "const" . to_string( ) ) ,
2303+ ] ,
2304+ vec ! [
2305+ (
2306+ Arc :: new( Literal :: new( ScalarValue :: Float32 ( None ) ) ) ,
2307+ "a" . to_string( ) ,
2308+ ) ,
2309+ (
2310+ Arc :: new( Literal :: new( ScalarValue :: Float32 ( None ) ) ) ,
2311+ "b" . to_string( ) ,
2312+ ) ,
2313+ (
2314+ Arc :: new( Literal :: new( ScalarValue :: Int32 ( None ) ) ) ,
2315+ "const" . to_string( ) ,
2316+ ) ,
2317+ ] ,
2318+ vec ! [
2319+ vec![ false , true , true ] ,
2320+ vec![ true , false , true ] ,
2321+ vec![ true , true , false ] ,
2322+ ] ,
2323+ ) ;
2324+
2325+ let aggregates: Vec < Arc < dyn AggregateExpr > > = vec ! [ create_aggregate_expr(
2326+ count_udaf( ) . as_ref( ) ,
2327+ & [ lit( 1 ) ] ,
2328+ & [ datafusion_expr:: lit( 1 ) ] ,
2329+ & [ ] ,
2330+ & [ ] ,
2331+ schema. as_ref( ) ,
2332+ "1" ,
2333+ false ,
2334+ false ,
2335+ ) ?] ;
2336+
2337+ let input_batches = ( 0 ..4 )
2338+ . map ( |_| {
2339+ let a = Arc :: new ( Float32Array :: from ( vec ! [ 0. ; 8192 ] ) ) ;
2340+ let b = Arc :: new ( Float32Array :: from ( vec ! [ 0. ; 8192 ] ) ) ;
2341+ let c = Arc :: new ( Int32Array :: from ( vec ! [ 1 ; 8192 ] ) ) ;
2342+
2343+ RecordBatch :: try_new ( schema. clone ( ) , vec ! [ a, b, c] ) . unwrap ( )
2344+ } )
2345+ . collect ( ) ;
2346+
2347+ let input =
2348+ Arc :: new ( MemoryExec :: try_new ( & [ input_batches] , schema. clone ( ) , None ) ?) ;
2349+
2350+ let aggregate_exec = Arc :: new ( AggregateExec :: try_new (
2351+ AggregateMode :: Partial ,
2352+ groups,
2353+ aggregates. clone ( ) ,
2354+ vec ! [ None ] ,
2355+ input,
2356+ schema,
2357+ ) ?) ;
2358+
2359+ let output =
2360+ collect ( aggregate_exec. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?) . await ?;
2361+
2362+ let expected = [
2363+ "+-----+-----+-------+----------+" ,
2364+ "| a | b | const | 1[count] |" ,
2365+ "+-----+-----+-------+----------+" ,
2366+ "| | 0.0 | | 32768 |" ,
2367+ "| 0.0 | | | 32768 |" ,
2368+ "| | | 1 | 32768 |" ,
2369+ "+-----+-----+-------+----------+" ,
2370+ ] ;
2371+ assert_batches_sorted_eq ! ( expected, & output) ;
2372+
2373+ Ok ( ( ) )
2374+ }
22702375}
0 commit comments