Skip to content

Commit 245f5ee

Browse files
authored
[ML] Integrate SageMaker with OpenAI Embeddings (#126856)
Integrating with SageMaker. Current design: - SageMaker accepts any byte payload, which can be text, csv, or json. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well. - `api` implementations are extensions of `SageMakerSchemaPayload`, which supports: - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings` - conversion logic from model, service settings, task settings, and input to `SdkBytes` - conversion logic from responding `SdkBytes` to `InferenceServiceResults` - Everything else is tunneling, there are a number of base `service_settings` and `task_settings` that are independent of the api format that we will store and set - We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
1 parent 1b35cce commit 245f5ee

39 files changed

+4311
-146
lines changed

docs/changelog/126856.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126856
2+
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

gradle/verification-metadata.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4912,6 +4912,11 @@
49124912
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
49134913
</artifact>
49144914
</component>
4915+
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
4916+
<artifact name="sagemakerruntime-2.30.38.jar">
4917+
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
4918+
</artifact>
4919+
</component>
49154920
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
49164921
<artifact name="sdk-core-2.30.38.jar">
49174922
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ static TransportVersion def(int id) {
162162
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
163163
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
164164
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
165+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
165166
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
166167
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
167168
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
@@ -232,6 +233,7 @@ static TransportVersion def(int id) {
232233
public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0);
233234
public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0);
234235
public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00);
236+
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00);
235237

236238
/*
237239
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/common/ValidationException.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ public final List<String> validationErrors() {
5353
return validationErrors;
5454
}
5555

56+
public final void throwIfValidationErrorsExist() {
57+
if (validationErrors().isEmpty() == false) {
58+
throw this;
59+
}
60+
}
61+
5662
@Override
5763
public final String getMessage() {
5864
StringBuilder sb = new StringBuilder();

x-pack/plugin/inference/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ dependencies {
6262

6363
/* AWS SDK v2 */
6464
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
65+
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
6566
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
6667
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
6768
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
@@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure {
142143
mapping from: /json-utils.*/, to: 'aws-sdk-2'
143144
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
144145
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
146+
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
145147
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
146148
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
147149
mapping from: /netty-buffer/, to: 'netty'

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 115 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -18,163 +18,161 @@
1818
import java.util.Map;
1919

2020
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
21+
import static org.hamcrest.Matchers.containsInAnyOrder;
2122
import static org.hamcrest.Matchers.equalTo;
2223

2324
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2425

25-
@SuppressWarnings("unchecked")
2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(21));
29-
30-
String[] providers = new String[services.size()];
31-
for (int i = 0; i < services.size(); i++) {
32-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
33-
providers[i] = (String) serviceConfig.get("service");
34-
}
35-
36-
assertArrayEquals(
37-
List.of(
38-
"alibabacloud-ai-search",
39-
"amazonbedrock",
40-
"anthropic",
41-
"azureaistudio",
42-
"azureopenai",
43-
"cohere",
44-
"deepseek",
45-
"elastic",
46-
"elasticsearch",
47-
"googleaistudio",
48-
"googlevertexai",
49-
"hugging_face",
50-
"jinaai",
51-
"mistral",
52-
"openai",
53-
"streaming_completion_test_service",
54-
"test_reranking_service",
55-
"test_service",
56-
"text_embedding_test_service",
57-
"voyageai",
58-
"watsonxai"
59-
).toArray(),
60-
providers
28+
assertThat(services.size(), equalTo(22));
29+
30+
var providers = providers(services);
31+
32+
assertThat(
33+
providers,
34+
containsInAnyOrder(
35+
List.of(
36+
"alibabacloud-ai-search",
37+
"amazonbedrock",
38+
"anthropic",
39+
"azureaistudio",
40+
"azureopenai",
41+
"cohere",
42+
"deepseek",
43+
"elastic",
44+
"elasticsearch",
45+
"googleaistudio",
46+
"googlevertexai",
47+
"hugging_face",
48+
"jinaai",
49+
"mistral",
50+
"openai",
51+
"streaming_completion_test_service",
52+
"test_reranking_service",
53+
"test_service",
54+
"text_embedding_test_service",
55+
"voyageai",
56+
"watsonxai",
57+
"sagemaker"
58+
).toArray()
59+
)
6160
);
6261
}
6362

6463
@SuppressWarnings("unchecked")
64+
private Iterable<String> providers(List<Object> services) {
65+
return services.stream().map(service -> {
66+
var serviceConfig = (Map<String, Object>) service;
67+
return (String) serviceConfig.get("service");
68+
}).toList();
69+
}
70+
6571
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
6672
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
67-
assertThat(services.size(), equalTo(15));
68-
69-
String[] providers = new String[services.size()];
70-
for (int i = 0; i < services.size(); i++) {
71-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
72-
providers[i] = (String) serviceConfig.get("service");
73-
}
74-
75-
assertArrayEquals(
76-
List.of(
77-
"alibabacloud-ai-search",
78-
"amazonbedrock",
79-
"azureaistudio",
80-
"azureopenai",
81-
"cohere",
82-
"elasticsearch",
83-
"googleaistudio",
84-
"googlevertexai",
85-
"hugging_face",
86-
"jinaai",
87-
"mistral",
88-
"openai",
89-
"text_embedding_test_service",
90-
"voyageai",
91-
"watsonxai"
92-
).toArray(),
93-
providers
73+
assertThat(services.size(), equalTo(16));
74+
75+
var providers = providers(services);
76+
77+
assertThat(
78+
providers,
79+
containsInAnyOrder(
80+
List.of(
81+
"alibabacloud-ai-search",
82+
"amazonbedrock",
83+
"azureaistudio",
84+
"azureopenai",
85+
"cohere",
86+
"elasticsearch",
87+
"googleaistudio",
88+
"googlevertexai",
89+
"hugging_face",
90+
"jinaai",
91+
"mistral",
92+
"openai",
93+
"text_embedding_test_service",
94+
"voyageai",
95+
"watsonxai",
96+
"sagemaker"
97+
).toArray()
98+
)
9499
);
95100
}
96101

97-
@SuppressWarnings("unchecked")
98102
public void testGetServicesWithRerankTaskType() throws IOException {
99103
List<Object> services = getServices(TaskType.RERANK);
100104
assertThat(services.size(), equalTo(7));
101105

102-
String[] providers = new String[services.size()];
103-
for (int i = 0; i < services.size(); i++) {
104-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
105-
providers[i] = (String) serviceConfig.get("service");
106-
}
107-
108-
assertArrayEquals(
109-
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
110-
.toArray(),
111-
providers
106+
var providers = providers(services);
107+
108+
assertThat(
109+
providers,
110+
containsInAnyOrder(
111+
List.of(
112+
"alibabacloud-ai-search",
113+
"cohere",
114+
"elasticsearch",
115+
"googlevertexai",
116+
"jinaai",
117+
"test_reranking_service",
118+
"voyageai"
119+
).toArray()
120+
)
112121
);
113122
}
114123

115-
@SuppressWarnings("unchecked")
116124
public void testGetServicesWithCompletionTaskType() throws IOException {
117125
List<Object> services = getServices(TaskType.COMPLETION);
118126
assertThat(services.size(), equalTo(10));
119127

120-
String[] providers = new String[services.size()];
121-
for (int i = 0; i < services.size(); i++) {
122-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
123-
providers[i] = (String) serviceConfig.get("service");
124-
}
125-
126-
assertArrayEquals(
127-
List.of(
128-
"alibabacloud-ai-search",
129-
"amazonbedrock",
130-
"anthropic",
131-
"azureaistudio",
132-
"azureopenai",
133-
"cohere",
134-
"deepseek",
135-
"googleaistudio",
136-
"openai",
137-
"streaming_completion_test_service"
138-
).toArray(),
139-
providers
128+
var providers = providers(services);
129+
130+
assertThat(
131+
providers,
132+
containsInAnyOrder(
133+
List.of(
134+
"alibabacloud-ai-search",
135+
"amazonbedrock",
136+
"anthropic",
137+
"azureaistudio",
138+
"azureopenai",
139+
"cohere",
140+
"deepseek",
141+
"googleaistudio",
142+
"openai",
143+
"streaming_completion_test_service"
144+
).toArray()
145+
)
140146
);
141147
}
142148

143-
@SuppressWarnings("unchecked")
144149
public void testGetServicesWithChatCompletionTaskType() throws IOException {
145150
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
146151
assertThat(services.size(), equalTo(4));
147152

148-
String[] providers = new String[services.size()];
149-
for (int i = 0; i < services.size(); i++) {
150-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
151-
providers[i] = (String) serviceConfig.get("service");
152-
}
153+
var providers = providers(services);
153154

154-
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
155+
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
155156
}
156157

157-
@SuppressWarnings("unchecked")
158158
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
159159
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
160160
assertThat(services.size(), equalTo(6));
161161

162-
String[] providers = new String[services.size()];
163-
for (int i = 0; i < services.size(); i++) {
164-
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
165-
providers[i] = (String) serviceConfig.get("service");
166-
}
167-
168-
assertArrayEquals(
169-
List.of(
170-
"alibabacloud-ai-search",
171-
"elastic",
172-
"elasticsearch",
173-
"hugging_face",
174-
"streaming_completion_test_service",
175-
"test_service"
176-
).toArray(),
177-
providers
162+
var providers = providers(services);
163+
164+
assertThat(
165+
providers,
166+
containsInAnyOrder(
167+
List.of(
168+
"alibabacloud-ai-search",
169+
"elastic",
170+
"elasticsearch",
171+
"hugging_face",
172+
"streaming_completion_test_service",
173+
"test_service"
174+
).toArray()
175+
)
178176
);
179177
}
180178

x-pack/plugin/inference/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
requires org.elasticsearch.logging;
3737
requires org.elasticsearch.sslconfig;
3838
requires org.apache.commons.text;
39+
requires software.amazon.awssdk.services.sagemakerruntime;
3940

4041
exports org.elasticsearch.xpack.inference.action;
4142
exports org.elasticsearch.xpack.inference.registry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
9393
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
9494
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
95+
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
96+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
9597
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
9698
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
9799
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
@@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
157159

158160
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
159161
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
162+
namedWriteables.addAll(SageMakerModel.namedWriteables());
163+
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
160164

161165
return namedWriteables;
162166
}

0 commit comments

Comments
 (0)