Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/117655.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117655
summary: Add nulls support to Categorize
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
Expand All @@ -31,11 +33,21 @@
* Base BlockHash implementation for {@code Categorize} grouping function.
*/
public abstract class AbstractCategorizeBlockHash extends BlockHash {
protected static final int NULL_ORD = 0;

// TODO: this should probably also take an emitBatchSize
private final int channel;
private final boolean outputPartial;
protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

/**
* Store whether we've seen any {@code null} values.
* <p>
* Null gets the {@link #NULL_ORD} ord.
* </p>
*/
protected boolean seenNull = false;

AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
super(blockFactory);
this.channel = channel;
Expand All @@ -58,12 +70,12 @@ public Block[] getKeys() {

@Override
public IntVector nonEmpty() {
return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory);
}

@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
throw new UnsupportedOperationException();
return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
}

@Override
Expand All @@ -76,24 +88,39 @@ public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue target
*/
private Block buildIntermediateBlock() {
if (categorizer.getCategoryCount() == 0) {
return blockFactory.newConstantNullBlock(0);
return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
}
try (BytesStreamOutput out = new BytesStreamOutput()) {
// TODO be more careful here.
out.writeBoolean(seenNull);
out.writeVInt(categorizer.getCategoryCount());
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
category.writeTo(out);
}
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private Block buildFinalBlock() {
BytesRefBuilder scratch = new BytesRefBuilder();

if (seenNull) {
try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
result.appendNull();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
result.appendBytesRef(scratch.get());
scratch.clear();
}
return result.build();
}
}

try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
BytesRefBuilder scratch = new BytesRefBuilder();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
result.appendBytesRef(scratch.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void close() {
/**
* Similar implementation to an Evaluator.
*/
public static final class CategorizeEvaluator implements Releasable {
public final class CategorizeEvaluator implements Releasable {
private final CategorizationAnalyzer analyzer;

private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
Expand Down Expand Up @@ -95,7 +95,8 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
BytesRef vScratch = new BytesRef();
for (int p = 0; p < positionCount; p++) {
if (vBlock.isNull(p)) {
result.appendNull();
seenNull = true;
result.appendInt(NULL_ORD);
continue;
}
int first = vBlock.getFirstValueIndex(p);
Expand Down Expand Up @@ -126,7 +127,12 @@ public IntVector eval(int positionCount, BytesRefVector vVector) {
}

private int process(BytesRef v) {
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
var category = categorizer.computeCategory(v.utf8ToString(), analyzer);
if (category == null) {
seenNull = true;
return NULL_ORD;
}
return category.getId() + 1;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,19 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
return;
}
BytesRefBlock categorizerState = page.getBlock(channel());
if (categorizerState.areAllValuesNull()) {
seenNull = true;
try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
addInput.add(0, newIds);
}
return;
}

Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
for (int i = 0; i < idMap.size(); i++) {
int fromId = idMap.containsKey(0) ? 0 : 1;
int toId = fromId + idMap.size();
for (int i = fromId; i < toId; i++) {
newIdsBuilder.appendInt(idMap.get(i));
}
try (IntBlock newIds = newIdsBuilder.build()) {
Expand All @@ -59,10 +69,15 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
Map<Integer, Integer> idMap = new HashMap<>();
try (StreamInput in = new BytesArray(bytes).streamInput()) {
if (in.readBoolean()) {
seenNull = true;
idMap.put(NULL_ORD, NULL_ORD);
}
int count = in.readVInt();
for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
idMap.put(oldCategoryId, newCategoryId);
// +1 because the 0 ordinal is reserved for null
idMap.put(oldCategoryId + 1, newCategoryId + 1);
}
return idMap;
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {

public void testCategorizeRaw() {
final Page page;
final int positions = 7;
boolean withNull = randomBoolean();
final int positions = 7 + (withNull ? 1 : 0);
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
Expand All @@ -61,6 +62,13 @@ public void testCategorizeRaw() {
builder.appendBytesRef(new BytesRef("Disconnected"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
if (withNull) {
if (randomBoolean()) {
builder.appendNull();
} else {
builder.appendBytesRef(new BytesRef(""));
}
}
page = new Page(builder.build());
}

Expand All @@ -70,13 +78,16 @@ public void testCategorizeRaw() {
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions);

assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(1, groupIds.getInt(2));
assertEquals(1, groupIds.getInt(3));
assertEquals(2, groupIds.getInt(4));
assertEquals(0, groupIds.getInt(5));
assertEquals(0, groupIds.getInt(6));
assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

@Override
Expand All @@ -100,7 +111,8 @@ public void close() {

public void testCategorizeIntermediate() {
Page page1;
int positions1 = 7;
boolean withNull = randomBoolean();
int positions1 = 7 + (withNull ? 1 : 0);
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
builder.appendBytesRef(new BytesRef("Connection error"));
Expand All @@ -109,6 +121,13 @@ public void testCategorizeIntermediate() {
builder.appendBytesRef(new BytesRef("Connection error"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
if (withNull) {
if (randomBoolean()) {
builder.appendNull();
} else {
builder.appendBytesRef(new BytesRef(""));
}
}
page1 = new Page(builder.build());
}
Page page2;
Expand All @@ -133,13 +152,16 @@ public void testCategorizeIntermediate() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions1);
assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(1, groupIds.getInt(2));
assertEquals(0, groupIds.getInt(3));
assertEquals(1, groupIds.getInt(4));
assertEquals(0, groupIds.getInt(5));
assertEquals(0, groupIds.getInt(6));
assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(2, groupIds.getInt(2));
assertEquals(1, groupIds.getInt(3));
assertEquals(2, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(5));
assertEquals(1, groupIds.getInt(6));
if (withNull) {
assertEquals(0, groupIds.getInt(7));
}
}

@Override
Expand All @@ -158,11 +180,11 @@ public void close() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
assertEquals(groupIds.getPositionCount(), positions2);
assertEquals(0, groupIds.getInt(0));
assertEquals(1, groupIds.getInt(1));
assertEquals(0, groupIds.getInt(2));
assertEquals(1, groupIds.getInt(3));
assertEquals(2, groupIds.getInt(4));
assertEquals(1, groupIds.getInt(0));
assertEquals(2, groupIds.getInt(1));
assertEquals(1, groupIds.getInt(2));
assertEquals(2, groupIds.getInt(3));
assertEquals(3, groupIds.getInt(4));
}

@Override
Expand All @@ -189,7 +211,11 @@ public void add(int positionOffset, IntBlock groupIds) {
.map(groupIds::getInt)
.boxed()
.collect(Collectors.toSet());
assertEquals(values, Set.of(0, 1));
if (withNull) {
assertEquals(Set.of(0, 1, 2), values);
} else {
assertEquals(Set.of(1, 2), values);
}
}

@Override
Expand All @@ -212,7 +238,7 @@ public void add(int positionOffset, IntBlock groupIds) {
.collect(Collectors.toSet());
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
// 0 matches an existing category (Connected to ...), and the others are new.
assertEquals(values, Set.of(0, 2, 3));
assertEquals(Set.of(1, 3, 4), values);
}

@Override
Expand Down
Loading