Skip to content

Commit 7e9a7c3

Browse files
jonathan-buttnerelasticsearchmachinedavidkyle
authored
[ML] Adding response parsers for custom service (#127179) (#127478)
* Adding response parsers for custom service * [CI] Auto commit changes from spotless * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java * Refactoring to include field names in exceptions * Adding list entry index to error message and field names * Addressing feedback for validation exception --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> Co-authored-by: David Kyle <david.kyle@elastic.co>
1 parent 30d5a9d commit 7e9a7c3

18 files changed

+2654
-23
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import org.elasticsearch.common.Strings;
1111

1212
import java.util.ArrayList;
13+
import java.util.LinkedHashSet;
1314
import java.util.List;
1415
import java.util.Map;
16+
import java.util.Set;
1517
import java.util.regex.Pattern;
1618

1719
/**
@@ -78,6 +80,8 @@
7880
* [1, 2]
7981
* ]
8082
* }
83+
*
84+
* The array field names would be {@code ["embeddings", "embedding"}
8185
* </pre>
8286
*
8387
* This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
@@ -123,10 +127,28 @@ public class MapPathExtractor {
123127
private static final String DOLLAR = "$";
124128

125129
// default for testing
126-
static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
127-
static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
130+
static final Pattern DOT_FIELD_PATTERN = Pattern.compile("^\\.([^.\\[]+)(.*)");
131+
static final Pattern ARRAY_WILDCARD_PATTERN = Pattern.compile("^\\[\\*\\](.*)");
132+
public static final String UNKNOWN_FIELD_NAME = "unknown";
133+
134+
/**
135+
* A result object that tries to match up the field names parsed from the passed in path and the result
136+
* extracted from the passed in map.
137+
* @param extractedObject represents the extracted result from the map
138+
* @param traversedFields a list of field names in order as they're encountered while navigating through the nested objects
139+
*/
140+
public record Result(Object extractedObject, List<String> traversedFields) {
141+
public String getArrayFieldName(int index) {
142+
// if the index is out of bounds we'll return a default value
143+
if (traversedFields.size() <= index || index < 0) {
144+
return UNKNOWN_FIELD_NAME;
145+
}
146+
147+
return traversedFields.get(index);
148+
}
149+
}
128150

129-
public static Object extract(Map<String, Object> data, String path) {
151+
public static Result extract(Map<String, Object> data, String path) {
130152
if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
131153
return null;
132154
}
@@ -139,16 +161,41 @@ public static Object extract(Map<String, Object> data, String path) {
139161
throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath));
140162
}
141163

142-
return navigate(data, cleanedPath);
164+
var fieldNames = new LinkedHashSet<String>();
165+
166+
return new Result(navigate(data, cleanedPath, new FieldNameInfo("", "", fieldNames)), fieldNames.stream().toList());
143167
}
144168

145-
private static Object navigate(Object current, String remainingPath) {
146-
if (current == null || remainingPath == null || remainingPath.isEmpty()) {
169+
private record FieldNameInfo(String currentPath, String fieldName, Set<String> traversedFields) {
170+
void addTraversedField(String fieldName) {
171+
traversedFields.add(createPath(fieldName));
172+
}
173+
174+
void addCurrentField() {
175+
traversedFields.add(currentPath);
176+
}
177+
178+
FieldNameInfo descend(String newFieldName) {
179+
var newLocation = createPath(newFieldName);
180+
return new FieldNameInfo(newLocation, newFieldName, traversedFields);
181+
}
182+
183+
private String createPath(String newFieldName) {
184+
if (Strings.isNullOrEmpty(currentPath)) {
185+
return newFieldName;
186+
} else {
187+
return currentPath + "." + newFieldName;
188+
}
189+
}
190+
}
191+
192+
private static Object navigate(Object current, String remainingPath, FieldNameInfo fieldNameInfo) {
193+
if (current == null || Strings.isNullOrEmpty(remainingPath)) {
147194
return current;
148195
}
149196

150-
var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
151-
var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
197+
var dotFieldMatcher = DOT_FIELD_PATTERN.matcher(remainingPath);
198+
var arrayWildcardMatcher = ARRAY_WILDCARD_PATTERN.matcher(remainingPath);
152199

153200
if (dotFieldMatcher.matches()) {
154201
String field = dotFieldMatcher.group(1);
@@ -168,7 +215,12 @@ private static Object navigate(Object current, String remainingPath) {
168215
throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
169216
}
170217

171-
return navigate(currentMap.get(field), nextPath);
218+
// Handle the case where the path was $.result.text or $.result[*].key
219+
if (Strings.isNullOrEmpty(nextPath)) {
220+
fieldNameInfo.addTraversedField(field);
221+
}
222+
223+
return navigate(currentMap.get(field), nextPath, fieldNameInfo.descend(field));
172224
} else {
173225
throw new IllegalArgumentException(
174226
Strings.format(
@@ -182,10 +234,12 @@ private static Object navigate(Object current, String remainingPath) {
182234
} else if (arrayWildcardMatcher.matches()) {
183235
String nextPath = arrayWildcardMatcher.group(1);
184236
if (current instanceof List<?> list) {
237+
fieldNameInfo.addCurrentField();
238+
185239
List<Object> results = new ArrayList<>();
186240

187241
for (Object item : list) {
188-
Object result = navigate(item, nextPath);
242+
Object result = navigate(item, nextPath, fieldNameInfo);
189243
if (result != null) {
190244
results.add(result);
191245
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,16 @@ public String getErrorMessage() {
3434
public boolean errorStructureFound() {
3535
return errorStructureFound;
3636
}
37+
38+
@Override
39+
public boolean equals(Object o) {
40+
if (o == null || getClass() != o.getClass()) return false;
41+
ErrorResponse that = (ErrorResponse) o;
42+
return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage);
43+
}
44+
45+
@Override
46+
public int hashCode() {
47+
return Objects.hash(errorMessage, errorStructureFound);
48+
}
3749
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.custom;
9+
10+
public class CustomServiceSettings {
11+
public static final String NAME = "custom_service_settings";
12+
public static final String URL = "url";
13+
public static final String HEADERS = "headers";
14+
public static final String REQUEST = "request";
15+
public static final String REQUEST_CONTENT = "content";
16+
public static final String RESPONSE = "response";
17+
public static final String JSON_PARSER = "json_parser";
18+
public static final String ERROR_PARSER = "error_parser";
19+
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.custom.response;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.XContentFactory;
13+
import org.elasticsearch.xcontent.XContentParser;
14+
import org.elasticsearch.xcontent.XContentParserConfiguration;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
17+
18+
import java.io.IOException;
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
import java.util.function.BiFunction;
24+
25+
public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {
26+
27+
@Override
28+
public InferenceServiceResults parse(HttpResult response) throws IOException {
29+
try (
30+
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
31+
.createParser(XContentParserConfiguration.EMPTY, response.body())
32+
) {
33+
var map = jsonParser.map();
34+
35+
return transform(map);
36+
}
37+
}
38+
39+
protected abstract T transform(Map<String, Object> extractedField);
40+
41+
static List<?> validateList(Object obj, String fieldName) {
42+
validateNonNull(obj, fieldName);
43+
44+
if (obj instanceof List<?> == false) {
45+
throw new IllegalArgumentException(
46+
Strings.format(
47+
"Extracted field [%s] is an invalid type, expected a list but received [%s]",
48+
fieldName,
49+
obj.getClass().getSimpleName()
50+
)
51+
);
52+
}
53+
54+
return (List<?>) obj;
55+
}
56+
57+
static void validateNonNull(Object obj, String fieldName) {
58+
Objects.requireNonNull(obj, Strings.format("Failed to parse field [%s], extracted field was null", fieldName));
59+
}
60+
61+
static Map<String, Object> validateMap(Object obj, String fieldName) {
62+
validateNonNull(obj, fieldName);
63+
64+
if (obj instanceof Map<?, ?> == false) {
65+
throw new IllegalArgumentException(
66+
Strings.format(
67+
"Extracted field [%s] is an invalid type, expected a map but received [%s]",
68+
fieldName,
69+
obj.getClass().getSimpleName()
70+
)
71+
);
72+
}
73+
74+
var keys = ((Map<?, ?>) obj).keySet();
75+
for (var key : keys) {
76+
if (key instanceof String == false) {
77+
throw new IllegalStateException(
78+
Strings.format(
79+
"Extracted field [%s] map has an invalid key type. Expected a string but received [%s]",
80+
fieldName,
81+
key.getClass().getSimpleName()
82+
)
83+
);
84+
}
85+
}
86+
87+
@SuppressWarnings("unchecked")
88+
var result = (Map<String, Object>) obj;
89+
return result;
90+
}
91+
92+
static List<Float> convertToListOfFloats(Object obj, String fieldName) {
93+
return castList(validateList(obj, fieldName), BaseCustomResponseParser::toFloat, fieldName);
94+
}
95+
96+
static Float toFloat(Object obj, String fieldName) {
97+
return toNumber(obj, fieldName).floatValue();
98+
}
99+
100+
private static Number toNumber(Object obj, String fieldName) {
101+
if (obj instanceof Number == false) {
102+
throw new IllegalArgumentException(
103+
Strings.format("Unable to convert field [%s] of type [%s] to Number", fieldName, obj.getClass().getSimpleName())
104+
);
105+
}
106+
107+
return ((Number) obj);
108+
}
109+
110+
static List<Integer> convertToListOfIntegers(Object obj, String fieldName) {
111+
return castList(validateList(obj, fieldName), BaseCustomResponseParser::toInteger, fieldName);
112+
}
113+
114+
private static Integer toInteger(Object obj, String fieldName) {
115+
return toNumber(obj, fieldName).intValue();
116+
}
117+
118+
static <T> List<T> castList(List<?> items, BiFunction<Object, String, T> converter, String fieldName) {
119+
validateNonNull(items, fieldName);
120+
121+
List<T> resultList = new ArrayList<>();
122+
for (int i = 0; i < items.size(); i++) {
123+
try {
124+
resultList.add(converter.apply(items.get(i), fieldName));
125+
} catch (Exception e) {
126+
throw new IllegalStateException(Strings.format("Failed to parse list entry [%d], error: %s", i, e.getMessage()), e);
127+
}
128+
}
129+
130+
return resultList;
131+
}
132+
133+
static <T> T toType(Object obj, Class<T> type, String fieldName) {
134+
validateNonNull(obj, fieldName);
135+
136+
if (type.isInstance(obj) == false) {
137+
throw new IllegalArgumentException(
138+
Strings.format(
139+
"Unable to convert field [%s] of type [%s] to [%s]",
140+
fieldName,
141+
obj.getClass().getSimpleName(),
142+
type.getSimpleName()
143+
)
144+
);
145+
}
146+
147+
return type.cast(obj);
148+
}
149+
}

0 commit comments

Comments
 (0)