Skip to content
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8be20f0
RRFRetrieverComponent added:
mridula-s109 Jul 4, 2025
a8f6487
Modified parser, toXcontent and included component in the RetrieverBu…
mridula-s109 Jul 4, 2025
e07c38d
[CI] Auto commit changes from spotless
Jul 4, 2025
33d3da4
Resolved merge conflicts
mridula-s109 Jul 15, 2025
5fb5568
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 15, 2025
3ba149c
Fixed compile issues in tests
mridula-s109 Jul 15, 2025
d5749f6
[CI] Auto commit changes from spotless
Jul 15, 2025
7614936
trying to resolve parse errros
mridula-s109 Jul 16, 2025
a5d9e34
wip
ioanatia Jul 17, 2025
0640099
Modified builder
mridula-s109 Jul 17, 2025
cec23c2
[CI] Auto commit changes from spotless
Jul 17, 2025
6da9e15
Removed unnecessary code
mridula-s109 Jul 18, 2025
51b350e
Fixed import
mridula-s109 Jul 18, 2025
4050a3a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 18, 2025
ea664eb
Enhanced tests
mridula-s109 Jul 18, 2025
98e72be
Fixed the failing tests
mridula-s109 Jul 21, 2025
7de8c7a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 21, 2025
c778dd2
Yaml tests were added
mridula-s109 Jul 22, 2025
c7b331d
Added cluster features to it
mridula-s109 Jul 22, 2025
f543cbe
Fixed spotless
mridula-s109 Jul 22, 2025
75ab8d0
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 22, 2025
f5086c1
Update docs/changelog/130658.yaml
mridula-s109 Jul 22, 2025
fafb50f
Fixed the relaxed constraints
mridula-s109 Jul 23, 2025
e535864
Resolving issues
mridula-s109 Jul 23, 2025
78f8641
Resolved PR comments
mridula-s109 Jul 23, 2025
02647b1
removed simplified rrf
mridula-s109 Jul 23, 2025
2010f3a
changed the test file back to its original state
mridula-s109 Jul 24, 2025
7433023
Resolved comments to have ahelper method and the test case to use it
mridula-s109 Jul 24, 2025
a2bf4de
made parsing robust
mridula-s109 Jul 24, 2025
eebf577
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
0388abd
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
74ed8db
IT test reverted
mridula-s109 Jul 24, 2025
6d7e8ff
Replaced the declareString array parser
mridula-s109 Jul 25, 2025
f1e14ce
Enforced weights as nonnull
mridula-s109 Jul 25, 2025
fd30387
Fixed the weights null
mridula-s109 Jul 25, 2025
3a82a28
Empty weight shouldnt be serialised
mridula-s109 Jul 25, 2025
77c14d3
[CI] Auto commit changes from spotless
Jul 25, 2025
45ca068
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 25, 2025
2e121b0
removed the hard coding
mridula-s109 Jul 28, 2025
d66f5a6
Cleanup and optimised the code flow
mridula-s109 Jul 28, 2025
532e7df
Fixed the comments
mridula-s109 Jul 28, 2025
184330c
[CI] Auto commit changes from spotless
Jul 28, 2025
c43a075
optimised test
mridula-s109 Jul 28, 2025
e5f8079
Added additional test
mridula-s109 Jul 28, 2025
65cc528
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 28, 2025
4ac4940
addressed the commentS
mridula-s109 Jul 29, 2025
e6f22bc
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 29, 2025
3fc0d35
Update docs/changelog/130658.yaml
mridula-s109 Jul 31, 2025
5c364a0
Explicit check for retriever object
mridula-s109 Jul 31, 2025
978e182
Resolved PR comments
mridula-s109 Jul 31, 2025
032e946
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
1d807ee
Fixed the error message
mridula-s109 Jul 31, 2025
7f4c7cd
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
99f4ad2
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130658.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130658
summary: Implement support for weighted rrf
area: Relevance
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public Set<NodeFeature> getTestFeatures() {
LINEAR_RETRIEVER_L2_NORM,
LINEAR_RETRIEVER_MINSCORE_FIX,
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.WEIGHTED_SUPPORT
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
Expand All @@ -37,7 +38,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;

/**
* An rrf retriever is used to represent an rrf rank element, but
Expand All @@ -48,6 +49,7 @@
*/
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");

public static final String NAME = "rrf";

Expand All @@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
public static final ParseField QUERY_FIELD = new ParseField("query");

public static final int DEFAULT_RANK_CONSTANT = 60;

private final float[] weights;

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> {
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
List<String> fields = (List<String>) args[1];
String query = (String) args[2];
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];

List<RetrieverSource> innerRetrievers = childRetrievers != null
? childRetrievers.stream().map(RetrieverSource::from).toList()
: List.of();
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
int n = retrieverComponents.size();
List<RetrieverSource> innerRetrievers = new ArrayList<>(n);
float[] weights = new float[n];
for (int i = 0; i < n; i++) {
RRFRetrieverComponent component = retrieverComponents.get(i);
innerRetrievers.add(RetrieverSource.from(component.retriever()));
weights[i] = component.weight();
}
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
}
);

static {
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
p.nextToken();
String name = p.currentName();
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
c.trackRetrieverUsage(retrieverBuilder.getName());
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
RetrieverBuilder.declareBaseParserFields(PARSER);
}

Expand All @@ -103,27 +106,46 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
private final int rankConstant;

public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
this(childRetrievers, null, null, rankWindowSize, rankConstant);
this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
}

/**
 * Builds a weight array holding {@code DEFAULT_WEIGHT} once for every entry of the given list.
 *
 * @param retrievers the list whose size determines the array length; may be {@code null}
 * @return a new array with one default weight per list entry, or an empty array when the list is {@code null}
 */
private static float[] createDefaultWeights(List<?> retrievers) {
    if (retrievers == null) {
        return new float[0];
    }
    final float[] uniform = new float[retrievers.size()];
    for (int slot = 0; slot < uniform.length; slot++) {
        uniform[slot] = DEFAULT_WEIGHT;
    }
    return uniform;
}

public RRFRetrieverBuilder(
List<RetrieverSource> childRetrievers,
List<String> fields,
String query,
int rankWindowSize,
int rankConstant
int rankConstant,
float[] weights
) {
// Use a mutable list for childRetrievers so that we can use addChild
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
this.fields = fields == null ? null : List.copyOf(fields);
this.query = query;
this.rankConstant = rankConstant;
Objects.requireNonNull(weights, "weights must not be null");
if (weights.length != innerRetrievers.size()) {
throw new IllegalArgumentException(
"weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]"
);
}
this.weights = weights;
}

/**
 * Returns the rank constant this builder was constructed with, i.e. the value added to a
 * document's rank in the denominator of its per-retriever RRF score contribution.
 */
public int rankConstant() {
return rankConstant;
}

/**
 * Returns the per-retriever weights; entry {@code i} is the weight used for inner retriever {@code i}.
 * <p>
 * The backing array itself is returned, not a copy, so changes made by the caller are visible to
 * this builder. Callers should treat the result as read-only.
 */
public float[] weights() {
return weights;
}

@Override
public String getName() {
return NAME;
Expand All @@ -137,6 +159,7 @@ public ActionRequestValidationException validate(
boolean allowPartialSearchResults
) {
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);

return MultiFieldsInnerRetrieverUtils.validateParams(
innerRetrievers,
fields,
Expand All @@ -151,7 +174,14 @@ public ActionRequestValidationException validate(

@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(
newRetrievers,
this.fields,
this.query,
this.rankWindowSize,
this.rankConstant,
this.weights
);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
Expand Down Expand Up @@ -183,7 +213,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults

// calculate the current rrf score for this document
// later used to sort and convert to a rank
value.score += 1.0f / (rankConstant + frank);
value.score += this.weights[findex] * (1.0f / (rankConstant + frank));

if (explain && value.positions != null && value.scores != null) {
// record the position for each query
Expand Down Expand Up @@ -238,10 +268,14 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
query,
localIndicesMetadata.values(),
r -> {
List<RetrieverSource> retrievers = r.stream()
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
.toList();
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
float[] weights = new float[r.size()];
for (int i = 0; i < r.size(); i++) {
var retriever = r.get(i);
retrievers.add(retriever.retrieverSource());
weights[i] = retriever.weight();
}
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
},
w -> {
if (w != 1.0f) {
Expand All @@ -255,7 +289,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
if (fieldsInnerRetrievers.isEmpty() == false) {
// TODO: This is an incomplete solution as it does not address other incomplete copy issues
// (such as dropping the retriever name and min score)
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
float[] weights = createDefaultWeights(fieldsInnerRetrievers);
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights);
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
} else {
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
Expand All @@ -266,29 +301,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
return rewritten;
}

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public boolean doEquals(Object o) {
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
return super.doEquals(o)
&& Objects.equals(fields, that.fields)
&& Objects.equals(query, that.query)
&& rankConstant == that.rankConstant;
}

@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
}

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
if (innerRetrievers.isEmpty() == false) {
builder.startArray(RETRIEVERS_FIELD.getPreferredName());

for (var entry : innerRetrievers) {
entry.retriever().toXContent(builder, params);
for (int i = 0; i < innerRetrievers.size(); i++) {
RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]);
component.toXContent(builder, params);
}
builder.endArray();
}
Expand All @@ -307,4 +326,20 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
}

// ---- FOR TESTING XCONTENT PARSING ----
/**
 * Compares the RRF-specific state of two builders (fields, query, rank constant and weights)
 * on top of whatever {@code super.doEquals} compares.
 *
 * @param o the object to compare against
 * @return {@code true} when the parent state and all RRF-specific state are equal
 */
@Override
public boolean doEquals(Object o) {
// NOTE(review): unchecked cast; presumably the caller has already verified the runtime class of o (confirm)
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
return super.doEquals(o)
&& Objects.equals(fields, that.fields)
&& Objects.equals(query, that.query)
&& rankConstant == that.rankConstant
// arrays only have identity-based equals(), so the weights are compared element by element
&& Arrays.equals(weights, that.weights);
}

/**
 * Hashes the same state that {@code doEquals} compares: the parent hash, fields, query,
 * rank constant and weights.
 */
@Override
public int doHashCode() {
// Arrays.hashCode is needed because an array's own hashCode() is identity-based
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.rank.rrf;

import org.elasticsearch.common.ParsingException;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

/**
 * One entry of the {@code retrievers} array of an RRF retriever: an inner retriever together with
 * the weight applied to that retriever's contribution.
 * <p>
 * {@link #fromXContent} accepts two formats:
 * <ul>
 *   <li>legacy: the component object holds the retriever directly, keyed by its type; the weight
 *       is {@code DEFAULT_WEIGHT}</li>
 *   <li>structured: the component object holds a {@code retriever} object (itself keyed by the
 *       retriever type) and an optional {@code weight} number, in any order</li>
 * </ul>
 * {@link #toXContent} always writes the structured format.
 */
public class RRFRetrieverComponent implements ToXContentObject {

    public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
    public static final ParseField WEIGHT_FIELD = new ParseField("weight");
    static final float DEFAULT_WEIGHT = 1f;

    final RetrieverBuilder retriever;
    final float weight;

    /**
     * @param retrieverBuilder the inner retriever; must not be {@code null}
     * @param weight           the weight for this retriever, or {@code null} to use {@code DEFAULT_WEIGHT}
     * @throws NullPointerException     if {@code retrieverBuilder} is {@code null}
     * @throws IllegalArgumentException if the weight is negative or NaN
     */
    public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) {
        this.retriever = Objects.requireNonNull(retrieverBuilder, "retrieverBuilder must not be null");
        this.weight = weight == null ? DEFAULT_WEIGHT : weight;
        // Expressed as a negated ">=" so that NaN, which fails every ordered comparison, is rejected
        // together with negative values; a plain "< 0" test would let NaN through.
        if ((this.weight >= 0) == false) {
            throw new IllegalArgumentException("[weight] must be non-negative, found [" + this.weight + "]");
        }
    }

    /** Returns the inner retriever; never {@code null}. */
    public RetrieverBuilder retriever() {
        return retriever;
    }

    /** Returns the weight of this component. */
    public float weight() {
        return weight;
    }

    /**
     * Writes this component in the structured format, i.e. an object with a {@code retriever}
     * field and a {@code weight} field.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException {
        builder.startObject();
        builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
        builder.field(WEIGHT_FIELD.getPreferredName(), weight);
        builder.endObject();
        return builder;
    }

    /**
     * Parses one component. The parser must be positioned on the component's {@code START_OBJECT}
     * token. The first field name decides the format: {@code retriever} or {@code weight} selects
     * the structured format, any other name is treated as a retriever type (legacy format).
     *
     * @param parser  the parser, positioned on the component's {@code START_OBJECT}
     * @param context the retriever parsing context, also used to track retriever usage
     * @return the parsed component
     * @throws ParsingException if the component is not an object, is empty, contains no retriever,
     *                          repeats {@code retriever} or {@code weight}, has a {@code retriever}
     *                          value that is not an object holding exactly one retriever, or
     *                          contains an unknown field
     * @throws IOException      if the underlying parser fails
     */
    public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "expected object but found [{}]", parser.currentToken());
        }

        // Peek at the first field to determine the format
        XContentParser.Token token = parser.nextToken();
        if (token == XContentParser.Token.END_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
        }
        if (token != XContentParser.Token.FIELD_NAME) {
            throw new ParsingException(parser.getTokenLocation(), "expected field name but found [{}]", token);
        }

        String firstFieldName = parser.currentName();
        boolean structured = RETRIEVER_FIELD.match(firstFieldName, parser.getDeprecationHandler())
            || WEIGHT_FIELD.match(firstFieldName, parser.getDeprecationHandler());

        if (structured == false) {
            // Legacy format: the first field name is the retriever type itself
            RetrieverBuilder retriever = parser.namedObject(RetrieverBuilder.class, firstFieldName, context);
            context.trackRetrieverUsage(retriever.getName());
            if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
                throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName());
            }
            return new RRFRetrieverComponent(retriever, DEFAULT_WEIGHT);
        }

        // Structured format: "retriever" and "weight" may appear in either order, each at most once
        RetrieverBuilder retriever = null;
        Float weight = null;

        do {
            String fieldName = parser.currentName();
            if (RETRIEVER_FIELD.match(fieldName, parser.getDeprecationHandler())) {
                if (retriever != null) {
                    throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified");
                }
                // The value must be an object whose single key is the retriever type. Validate each
                // step instead of advancing blindly, so malformed input yields a clear error.
                if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
                    throw new ParsingException(parser.getTokenLocation(), "retriever must be an object");
                }
                if (parser.nextToken() != XContentParser.Token.FIELD_NAME) {
                    throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
                }
                String retrieverType = parser.currentName();
                retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context);
                context.trackRetrieverUsage(retriever.getName());
                // The wrapper object must close right after the retriever, mirroring the legacy format check;
                // otherwise the parser would be left in the middle of the wrapper object.
                if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
                    throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName());
                }
            } else if (WEIGHT_FIELD.match(fieldName, parser.getDeprecationHandler())) {
                if (weight != null) {
                    throw new ParsingException(parser.getTokenLocation(), "[weight] field can only be specified once");
                }
                parser.nextToken();
                weight = parser.floatValue();
            } else {
                throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", fieldName);
            }
        } while (parser.nextToken() == XContentParser.Token.FIELD_NAME);

        if (retriever == null) {
            throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
        }

        return new RRFRetrieverComponent(retriever, weight);
    }
}
Loading