Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/134461.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 134461
summary: Propagates filter() to aggregation functions' surrogates
area: Aggregations
type: bug
issues:
- 134380
180 changes: 180 additions & 0 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -3145,3 +3145,183 @@ FROM employees
m:datetime | x:integer | d:boolean
1999-04-30T00:00:00.000Z | 2 | true
;

sumWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

FROM employees
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
| STATS sum1 = SUM(1),
sum2 = SUM(1) WHERE emp_no == 10080,
sum3 = SUM(1) WHERE emp_no < 10080,
sum4 = SUM(1) WHERE emp_no >= 10080,
sum5 = SUM(agg_metric),
sum6 = SUM(agg_metric) WHERE emp_no == 10080
;

sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
100 | 1 | 79 | 21 | 100.0 | 1.0
;

weightedAvgWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
| MV_EXPAND x
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
;

w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
;

maxWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
max2 = MAX(agg_metric),
max3 = MAX(x),
max4 = MAX(x) WHERE x > 3
;

max1:double | max2:double | max3:integer | max4:integer
3.0 | 5.0 | 5 | 5
;

minWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
min2 = MIN(agg_metric),
min3 = MIN(x),
min4 = MIN(x) WHERE x > 3
;

min1:double | min2:double | min3:integer | min4:integer
1.0 | 1.0 | 1 | 4
;

countWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS count1 = COUNT(x) WHERE x >= 3,
count2 = COUNT(x),
count3 = COUNT(agg_metric),
count4 = COUNT(agg_metric) WHERE x >=3,
count5 = COUNT(4) WHERE x >= 3,
count6 = COUNT(*) WHERE x >= 3,
count7 = COUNT([1,2,3]) WHERE x >= 3,
count8 = COUNT([1,2,3])
;

count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
;

countDistinctWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
count2 = COUNT_DISTINCT(x),
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
count4 = COUNT_DISTINCT(1)
;

count1:long | count2:long | count3:long | count4:long
3 | 5 | 1 | 1
;

avgWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS avg1 = AVG(x) WHERE x <= 3,
avg2 = AVG(x),
avg3 = AVG(agg_metric) WHERE x <=3,
avg4 = AVG(agg_metric)
;

avg1:double | avg2:double | avg3:double | avg4:double
2.0 | 3.0 | 2.0 | 3.0
;

percentileWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
percentile2 = PERCENTILE(x, 50)
;

percentile1:double | percentile2:double
2.0 | 3.0
;

medianWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS median1 = MEDIAN(x) WHERE x <= 3,
median2 = MEDIAN(x),
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
median4 = MEDIAN([5,6,7,8,9])
;

median1:double | median2:double | median3:double | median4:double
2.0 | 3.0 | 7.0 | 7.0
;

medianAbsoluteDeviationWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([5,6,7,8,9]) WHERE x <= 3,
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([5,6,7,8,9])
;

median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
1.0 | 1.0 | 1.0 | 1.0
;

topWithConditions
required_capability: stats_with_filtered_surrogate_fixed

FROM employees
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
;

min1:integer | min2:integer | max1:integer | max2:integer
10011 | [10011, 10012] | 10079 | [10079, 10078]
;
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,14 @@ public enum Cap {
*/
FN_PRESENT,

/**
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
* expression is replaced by something else on planning
* e.g. STATS SUM(1) WHERE x==3 is replaced by
* MV_SUM(const)*COUNT(* WHERE x == 3).
*/
STATS_WITH_FILTERED_SURROGATE_FIXED,

/**
* TO_DENSE_VECTOR function.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
return new Sum(
s,
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
filter()
);
}

if (field.foldable()) {
Expand All @@ -169,7 +173,7 @@ public Expression surrogate() {
return new Mul(
s,
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ public final AggregatorFunctionSupplier supplier() {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
return new Max(
source(),
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
filter()
);
}
return field().foldable() ? new MvMax(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ public Expression surrogate() {

return field.foldable()
? new MvMedian(s, new ToDouble(s, field))
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ public final AggregatorFunctionSupplier supplier() {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
return new Min(
source(),
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
filter()
);
}
return field().foldable() ? new MvMin(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
this(source, field, Literal.TRUE, SummationMode.COMPENSATED_LITERAL);
}

public Sum(
Source source,
@Param(name = "number", type = { "aggregate_metric_double", "double", "integer", "long" }) Expression field,
Expression filter
) {
this(source, field, filter, SummationMode.COMPENSATED_LITERAL);
}

public Sum(Source source, Expression field, Expression filter, Expression summationMode) {
super(source, field, filter, List.of(summationMode));
this.summationMode = summationMode;
Expand Down Expand Up @@ -163,6 +171,6 @@ public Expression surrogate() {
}

// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ public Expression surrogate() {
var s = source();
if (orderField() instanceof Literal && limitField() instanceof Literal && limitValue() == 1) {
if (orderValue()) {
return new Min(s, field());
return new Min(s, field(), filter());
} else {
return new Max(s, field());
return new Max(s, field(), filter());
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ public Expression surrogate() {
return new MvAvg(s, field);
}
if (weight.foldable()) {
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
} else {
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,31 @@ public void testFold() {
}, this::evaluate);
}

public void testSurrogateHasFilter() {
Expression expression = randomFrom(
buildLiteralExpression(testCase),
buildDeepCopyOfFieldExpression(testCase),
buildFieldExpression(testCase)
);
Comment on lines +185 to +189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think there's an overload with suppliers here. Something like:

Suggested change
Expression expression = randomFrom(
buildLiteralExpression(testCase),
buildDeepCopyOfFieldExpression(testCase),
buildFieldExpression(testCase)
);
Expression expression = randomFrom(
random(),
() -> buildLiteralExpression(testCase),
() -> buildDeepCopyOfFieldExpression(testCase),
() -> buildFieldExpression(testCase)
);

assumeTrue("expression should have no type errors", expression.typeResolved().resolved());

if (expression instanceof AggregateFunction && expression instanceof SurrogateExpression) {
var filter = ((AggregateFunction) expression).filter();

if (filter != null) {
var surrogate = ((SurrogateExpression) expression).surrogate();

if (surrogate != null) {
surrogate.forEachDown(AggregateFunction.class, child -> {
var surrogateFilter = child.filter();
assertEquals(filter, surrogateFilter);
});
}
}
}
}

private void aggregateSingleMode(Expression expression) {
Object result;
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
Expand Down