Skip to content

Commit 42b033f

Browse files
committed
Add optimized path for intermediate values aggregator
1 parent f135998 commit 42b033f

File tree

6 files changed

+186
-116
lines changed

6 files changed

+186
-116
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.util.BytesRef;
1313
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
1414
import org.elasticsearch.common.util.BigArrays;
15+
import org.elasticsearch.compute.aggregation.AggregatorFunction;
1516
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
1617
import org.elasticsearch.compute.aggregation.AggregatorMode;
1718
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
@@ -57,8 +58,8 @@
5758
* Benchmark for the {@code VALUES} aggregator that supports grouping by many many
5859
* many values.
5960
*/
60-
@Warmup(iterations = 5)
61-
@Measurement(iterations = 7)
61+
@Warmup(iterations = 1)
62+
@Measurement(iterations = 2)
6263
@BenchmarkMode(Mode.AverageTime)
6364
@OutputTimeUnit(TimeUnit.MILLISECONDS)
6465
@State(Scope.Thread)
@@ -107,22 +108,22 @@ static void selfTest() {
107108
private static final String INT = "int";
108109
private static final String LONG = "long";
109110

110-
@Param({ "1", "1000", /*"1000000"*/ })
111+
@Param({ "1", "1000", "1000000" })
111112
public int groups;
112113

113-
@Param({ BYTES_REF, INT, LONG })
114+
@Param({ BYTES_REF })
114115
public String dataType;
115116

116-
private static Operator operator(DriverContext driverContext, int groups, String dataType) {
117+
private static Operator operator(DriverContext driverContext, int groups, String dataType, AggregatorMode mode) {
117118
if (groups == 1) {
118119
return new AggregationOperator(
119-
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
120+
List.of(supplier(dataType).aggregatorFactory(mode, List.of(0)).apply(driverContext)),
120121
driverContext
121122
);
122123
}
123124
List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
124125
return new HashAggregationOperator(
125-
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
126+
List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))),
126127
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
127128
driverContext
128129
) {
@@ -341,13 +342,21 @@ public void run() {
341342

342343
private static void run(int groups, String dataType, int opCount) {
343344
DriverContext driverContext = driverContext();
344-
try (Operator operator = operator(driverContext, groups, dataType)) {
345-
Page page = page(groups, dataType);
346-
for (int i = 0; i < opCount; i++) {
347-
operator.addInput(page.shallowCopy());
345+
try (Operator finalAggregator = operator(driverContext, groups, dataType, AggregatorMode.FINAL)) {
346+
try (Operator initialAggregator = operator(driverContext, groups, dataType, AggregatorMode.INITIAL)) {
347+
Page rawPage = page(groups, dataType);
348+
for (int i = 0; i < opCount; i++) {
349+
initialAggregator.addInput(rawPage.shallowCopy());
350+
}
351+
initialAggregator.finish();
352+
Page intermediatePage = initialAggregator.getOutput();
353+
for (int i = 0; i < opCount; i++) {
354+
finalAggregator.addInput(intermediatePage.shallowCopy());
355+
}
356+
finalAggregator.finish();
357+
Page outputPage = finalAggregator.getOutput();
358+
checkExpected(groups, dataType, outputPage);
348359
}
349-
operator.finish();
350-
checkExpected(groups, dataType, operator.getOutput());
351360
}
352361
}
353362

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -609,63 +609,80 @@ private MethodSpec addIntermediateInput() {
609609
.collect(joining(" && "))
610610
);
611611
}
612-
if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
613-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
614-
}
615-
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
616-
{
617-
builder.addStatement("int groupId = groups.getInt(groupPosition)");
618-
if (aggState.declaredType().isPrimitive()) {
619-
if (warnExceptions.isEmpty()) {
620-
assert intermediateState.size() == 2;
621-
assert intermediateState.get(1).name().equals("seen");
622-
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
623-
} else {
624-
assert intermediateState.size() == 3;
625-
assert intermediateState.get(1).name().equals("seen");
626-
assert intermediateState.get(2).name().equals("failed");
627-
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
628-
{
629-
builder.addStatement("state.setFailed(groupId)");
612+
var bulkCombineIntermediateMethod = optionalStaticMethod(
613+
declarationType,
614+
requireVoidType(),
615+
requireName("combineIntermediate"),
616+
requireArgs(
617+
Stream.of(
618+
Stream.of(aggState.declaredType(), TypeName.INT, INT_VECTOR), // aggState, positionOffset, groupIds
619+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType)
620+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
621+
)
622+
);
623+
if (bulkCombineIntermediateMethod != null) {
624+
var states = intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::name).collect(Collectors.joining(","));
625+
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups," + states + ")", declarationType);
626+
} else {
627+
if (intermediateState.stream()
628+
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
629+
.anyMatch(n -> n.equals("BYTES_REF"))) {
630+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
631+
}
632+
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
633+
{
634+
builder.addStatement("int groupId = groups.getInt(groupPosition)");
635+
if (aggState.declaredType().isPrimitive()) {
636+
if (warnExceptions.isEmpty()) {
637+
assert intermediateState.size() == 2;
638+
assert intermediateState.get(1).name().equals("seen");
639+
builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
640+
} else {
641+
assert intermediateState.size() == 3;
642+
assert intermediateState.get(1).name().equals("seen");
643+
assert intermediateState.get(2).name().equals("failed");
644+
builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
645+
{
646+
builder.addStatement("state.setFailed(groupId)");
647+
}
648+
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
630649
}
631-
builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
632-
}
633650

634-
warningsBlock(builder, () -> {
635-
var name = intermediateState.get(0).name();
636-
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
637-
builder.addStatement(
638-
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
651+
warningsBlock(builder, () -> {
652+
var name = intermediateState.get(0).name();
653+
var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
654+
builder.addStatement(
655+
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
656+
declarationType,
657+
name,
658+
vectorAccessor
659+
);
660+
});
661+
builder.endControlFlow();
662+
} else {
663+
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
664+
requireStaticMethod(
639665
declarationType,
640-
name,
641-
vectorAccessor
666+
requireVoidType(),
667+
requireName("combineIntermediate"),
668+
requireArgs(
669+
Stream.of(
670+
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
671+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
672+
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
673+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
674+
)
675+
);
676+
builder.addStatement(
677+
"$T.combineIntermediate(state, groupId, "
678+
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
679+
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
680+
+ ")",
681+
declarationType
642682
);
643-
});
683+
}
644684
builder.endControlFlow();
645-
} else {
646-
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
647-
requireStaticMethod(
648-
declarationType,
649-
requireVoidType(),
650-
requireName("combineIntermediate"),
651-
requireArgs(
652-
Stream.of(
653-
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
654-
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
655-
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
656-
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
657-
)
658-
);
659-
660-
builder.addStatement(
661-
"$T.combineIntermediate(state, groupId, "
662-
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
663-
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
664-
+ ")",
665-
declarationType
666-
);
667685
}
668-
builder.endControlFlow();
669686
}
670687
return builder.build();
671688
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

Lines changed: 2 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

Lines changed: 1 addition & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput(
2828
if (valuesOrdinal == null) {
2929
return delegate;
3030
}
31-
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
32-
final IntVector hashIds;
33-
BytesRef spare = new BytesRef();
34-
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
35-
for (int p = 0; p < dict.getPositionCount(); p++) {
36-
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
37-
}
38-
hashIds = hashIdsBuilder.build();
39-
}
31+
final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
4032
IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock();
4133
return new GroupingAggregatorFunction.AddInput() {
4234
@Override
@@ -85,17 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) {
8577

8678
@Override
8779
public void add(int positionOffset, IntVector groupIds) {
88-
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
89-
int groupId = groupIds.getInt(groupPosition);
90-
if (ordinalIds.isNull(groupPosition + positionOffset)) {
91-
continue;
92-
}
93-
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
94-
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
95-
for (int v = valuesStart; v < valuesEnd; v++) {
96-
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
97-
}
98-
}
80+
addOrdinalInputBlock(state, positionOffset, groupIds, ordinalIds, hashIds);
9981
}
10082

10183
@Override
@@ -114,15 +96,7 @@ static GroupingAggregatorFunction.AddInput wrapAddInput(
11496
if (valuesOrdinal == null) {
11597
return delegate;
11698
}
117-
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
118-
final IntVector hashIds;
119-
BytesRef spare = new BytesRef();
120-
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
121-
for (int p = 0; p < dict.getPositionCount(); p++) {
122-
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
123-
}
124-
hashIds = hashIdsBuilder.build();
125-
}
99+
final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
126100
var ordinalIds = valuesOrdinal.getOrdinalsVector();
127101
return new GroupingAggregatorFunction.AddInput() {
128102
@Override
@@ -157,10 +131,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) {
157131

158132
@Override
159133
public void add(int positionOffset, IntVector groupIds) {
160-
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
161-
int groupId = groupIds.getInt(groupPosition);
162-
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
163-
}
134+
addOrdinalInputVector(state, positionOffset, groupIds, ordinalIds, hashIds);
164135
}
165136

166137
@Override
@@ -169,4 +140,86 @@ public void close() {
169140
}
170141
};
171142
}
143+
144+
static IntVector hashDict(ValuesBytesRefAggregator.GroupingState state, BytesRefVector dict) {
145+
BytesRef scratch = new BytesRef();
146+
try (var hashIdsBuilder = dict.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
147+
for (int p = 0; p < dict.getPositionCount(); p++) {
148+
final long hashId = BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, scratch)));
149+
hashIdsBuilder.appendInt(Math.toIntExact(hashId));
150+
}
151+
return hashIdsBuilder.build();
152+
}
153+
}
154+
155+
static void addOrdinalInputBlock(
156+
ValuesBytesRefAggregator.GroupingState state,
157+
int positionOffset,
158+
IntVector groupIds,
159+
IntBlock ordinalIds,
160+
IntVector hashIds
161+
) {
162+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
163+
final int valuePosition = p + positionOffset;
164+
final int groupId = groupIds.getInt(valuePosition);
165+
final int start = ordinalIds.getFirstValueIndex(valuePosition);
166+
final int end = start + ordinalIds.getValueCount(valuePosition);
167+
for (int i = start; i < end; i++) {
168+
int ord = ordinalIds.getInt(i);
169+
state.addValueOrdinal(groupId, hashIds.getInt(ord));
170+
}
171+
}
172+
}
173+
174+
static void addOrdinalInputVector(
175+
ValuesBytesRefAggregator.GroupingState state,
176+
int positionOffset,
177+
IntVector groupIds,
178+
IntVector ordinalIds,
179+
IntVector hashIds
180+
) {
181+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
182+
int groupId = groupIds.getInt(p);
183+
int ord = ordinalIds.getInt(p + positionOffset);
184+
state.addValueOrdinal(groupId, hashIds.getInt(ord));
185+
}
186+
}
187+
188+
static void combineIntermediateInputValues(
189+
ValuesBytesRefAggregator.GroupingState state,
190+
int positionOffset,
191+
IntVector groupIds,
192+
BytesRefBlock values
193+
) {
194+
BytesRefVector dict = null;
195+
IntBlock ordinals = null;
196+
{
197+
final OrdinalBytesRefBlock asOrdinals = values.asOrdinals();
198+
if (asOrdinals != null) {
199+
dict = asOrdinals.getDictionaryVector();
200+
ordinals = asOrdinals.getOrdinalsBlock();
201+
}
202+
}
203+
if (dict != null && dict.getPositionCount() < groupIds.getPositionCount()) {
204+
try (var hashIds = hashDict(state, dict)) {
205+
IntVector ordinalsVector = ordinals.asVector();
206+
if (ordinalsVector != null) {
207+
addOrdinalInputVector(state, positionOffset, groupIds, ordinalsVector, hashIds);
208+
} else {
209+
addOrdinalInputBlock(state, positionOffset, groupIds, ordinals, hashIds);
210+
}
211+
}
212+
} else {
213+
final BytesRef scratch = new BytesRef();
214+
for (int p = 0; p < groupIds.getPositionCount(); p++) {
215+
final int valuePosition = p + positionOffset;
216+
final int groupId = groupIds.getInt(valuePosition);
217+
final int start = values.getFirstValueIndex(valuePosition);
218+
final int end = start + values.getValueCount(valuePosition);
219+
for (int i = start; i < end; i++) {
220+
state.addValue(groupId, values.getBytesRef(i, scratch));
221+
}
222+
}
223+
}
224+
}
172225
}

0 commit comments

Comments
 (0)