Skip to content

Commit 4360eb2

Browse files
committed
Using record ID as index value when parsing Google Vertex AI rerank results
1 parent 573b8a9 commit 4360eb2

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import org.elasticsearch.xcontent.XContentType;
1818
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1919
import org.elasticsearch.xpack.inference.external.http.HttpResult;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
21+
import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiRerankRequest;
2022

2123
import java.io.IOException;
2224
import java.util.List;
@@ -109,14 +111,19 @@ private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser)
109111
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
110112
}
111113

112-
return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
114+
if (parsedRankedDoc.id == null) {
115+
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName()));
116+
}
117+
118+
return new RankedDocsResults.RankedDoc(Integer.parseInt(parsedRankedDoc.id), parsedRankedDoc.score, parsedRankedDoc.content);
113119
});
114120
}
115121

116-
private record RankedDoc(@Nullable Float score, @Nullable String content) {
122+
private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) {
117123

118124
private static final ParseField CONTENT = new ParseField("content");
119125
private static final ParseField SCORE = new ParseField("score");
126+
private static final ParseField ID = new ParseField("id");
120127
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
121128
"google_vertex_ai_rerank_response",
122129
true,
@@ -126,6 +133,7 @@ private record RankedDoc(@Nullable Float score, @Nullable String content) {
126133
static {
127134
PARSER.declareString(Builder::setContent, CONTENT);
128135
PARSER.declareFloat(Builder::setScore, SCORE);
136+
PARSER.declareString(Builder::setId, ID);
129137
}
130138

131139
public static RankedDoc parse(XContentParser parser) {
@@ -137,6 +145,7 @@ private static final class Builder {
137145

138146
private String content;
139147
private Float score;
148+
private String id;
140149

141150
private Builder() {}
142151

@@ -150,8 +159,13 @@ public Builder setContent(String content) {
150159
return this;
151160
}
152161

162+
public Builder setId(String id) {
163+
this.id = id;
164+
return this;
165+
}
166+
153167
public RankedDoc build() {
154-
return new RankedDoc(score, content);
168+
return new RankedDoc(score, content, id);
155169
}
156170
}
157171
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
3939
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
4040
);
4141

42-
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"))));
42+
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
4343
}
4444

4545
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -68,7 +68,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
6868

6969
assertThat(
7070
parsedResults.getRankedDocs(),
71-
is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
71+
is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
7272
);
7373
}
7474

0 commit comments

Comments
 (0)