Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/118375.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118375
summary: Check for presence of error object when validating streaming responses from integrations in the inference API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@

package org.elasticsearch.xpack.inference.external.alibabacloudsearch;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.alibabacloudsearch.AlibabaCloudSearchErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
* Defines how to handle various errors returned from the AlibabaCloudSearch integration.
Expand All @@ -28,21 +24,15 @@ public AlibabaCloudSearchResponseHandler(String requestType, ResponseParser pars
super(requestType, parseFunction, AlibabaCloudSearchErrorResponseEntity::fromResponse);
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move up into BaseResponseHandler.

throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

/**
* Validates the status code throws an RetryException if not in the range [200, 300).
*
* @param request The http request
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
int statusCode = result.response().getStatusLine().getStatusCode();
if (RestStatus.isSuccessful(statusCode)) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.anthropic;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
Expand All @@ -19,11 +18,9 @@
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;

public class AnthropicResponseHandler extends BaseResponseHandler {
Expand Down Expand Up @@ -54,13 +51,6 @@ public AnthropicResponseHandler(String requestType, ResponseParser parseFunction
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
Expand All @@ -83,7 +73,8 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.cohere;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
Expand All @@ -17,12 +16,9 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.NewlineDelimitedByteProcessor;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
* Defines how to handle various errors returned from the Cohere integration.
*
Expand All @@ -45,13 +41,6 @@ public CohereResponseHandler(String requestType, ResponseParser parseFunction, b
this.canHandleStreamingResponse = canHandleStreamingResponse;
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponse;
Expand All @@ -73,7 +62,8 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@

package org.elasticsearch.xpack.inference.external.elastic;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

public class ElasticInferenceServiceResponseHandler extends BaseResponseHandler {

Expand All @@ -26,13 +22,7 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference.external.googleaistudio;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -20,13 +19,11 @@
import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioErrorResponseEntity;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.io.IOException;
import java.util.concurrent.Flow;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

public class GoogleAiStudioResponseHandler extends BaseResponseHandler {

Expand All @@ -52,13 +49,6 @@ public GoogleAiStudioResponseHandler(
this.content = content;
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

/**
* Validates the status code and throws a RetryException if not in the range [200, 300).
*
Expand All @@ -67,7 +57,8 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@

package org.elasticsearch.xpack.inference.external.googlevertexai;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

public class GoogleVertexAiResponseHandler extends BaseResponseHandler {

Expand All @@ -28,13 +25,7 @@ public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFun
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

package org.elasticsearch.xpack.inference.external.http.retry;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.Objects;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

public abstract class BaseResponseHandler implements ResponseHandler {

Expand All @@ -27,14 +31,15 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String REDIRECTION = "Unhandled redirection";
public static final String CONTENT_TOO_LARGE = "Received a content too large status code";
public static final String UNSUCCESSFUL = "Received an unsuccessful status code";
public static final String SERVER_ERROR_OBJECT = "Received an error response";
public static final String BAD_REQUEST = "Received a bad request status code";
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
private final Function<HttpResult, ErrorMessage> errorParseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;

public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorMessage> errorParseFunction) {
public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
this.requestType = Objects.requireNonNull(requestType);
this.parseFunction = Objects.requireNonNull(parseFunction);
this.errorParseFunction = Objects.requireNonNull(errorParseFunction);
Expand All @@ -54,11 +59,42 @@ public String getRequestType() {
return requestType;
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);

// When the response is streamed the status code could be 200 but the error object will be set
// so we need to check for that specifically
checkForErrorObject(request, result);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have one potential concern for this. This will be executed for non-streaming and streaming code paths. So if for some reason we get a 200 response back (or some other non failure status code) and the response object is a valid response but also happens to have a field that the error object expects (which depends on each service implementation) then this would fail. I doubt that would happen.

If we were concerned about it we could create a new method validateStreamingResponse and only call checkForErrorObject in that method.

}

protected abstract void checkForFailureStatusCode(Request request, HttpResult result);

private void checkForErrorObject(Request request, HttpResult result) {
var errorEntity = errorParseFunction.apply(result);

if (errorEntity.errorStructureFound()) {
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
// client retry if necessary
// If we did want to retry here, we'll need to determine if this was a streaming request, if it was
// we shouldn't retry because that would replay the entire streaming request and the client would get
// duplicate chunks back
throw new RetryException(false, buildError(SERVER_ERROR_OBJECT, request, result, errorEntity));
}
}

protected Exception buildError(String message, Request request, HttpResult result) {
var errorEntityMsg = errorParseFunction.apply(result);
return buildError(message, request, result, errorEntityMsg);
}

protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
var responseStatusCode = result.response().getStatusLine().getStatusCode();

if (errorEntityMsg == null) {
if (errorResponse == null
|| errorResponse.errorStructureFound() == false
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage())) {
return new ElasticsearchStatusException(
format(
"%s for request from inference entity id [%s] status [%s]",
Expand All @@ -76,7 +112,7 @@ protected Exception buildError(String message, Request request, HttpResult resul
message,
request.getInferenceEntityId(),
responseStatusCode,
errorEntityMsg.getErrorMessage()
errorResponse.getErrorMessage()
),
toRestStatus(responseStatusCode)
);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.retry;

import java.util.Objects;

public class ErrorResponse {

// Denotes an error object that was not found
public static final ErrorResponse UNDEFINED_ERROR = new ErrorResponse(false);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I need this is to differentiate between an error object being present but a specific field in the object missing (for example message) and when no error exists. Previously we were returning null in both of those scenarios. Now if there's an exception or the main field that defines an error object doesn't exist we'll return this instance to indicate that there was no error present.

If a nested field is missing we return a created instance with errorStructureFound set to true.


private final String errorMessage;
private final boolean errorStructureFound;

public ErrorResponse(String errorMessage) {
this.errorMessage = Objects.requireNonNull(errorMessage);
this.errorStructureFound = true;
}

private ErrorResponse(boolean errorStructureFound) {
this.errorMessage = "";
this.errorStructureFound = errorStructureFound;
}

public String getErrorMessage() {
return errorMessage;
}

public boolean errorStructureFound() {
return errorStructureFound;
}
}
Loading