Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
5412548
Adding custom headers support openai text embeddings
jonathan-buttner Sep 17, 2025
579eb18
Update docs/changelog/134960.yaml
jonathan-buttner Sep 17, 2025
8277bb7
Adding headers to the service api result
jonathan-buttner Sep 17, 2025
a9eda90
[CI] Auto commit changes from spotless
Sep 17, 2025
f52efab
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 18, 2025
2894254
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 18, 2025
fc3457b
Merge branch 'main' into ml-openai-headers-embedding
jonathan-buttner Sep 18, 2025
4890f2d
Addressing feedback
jonathan-buttner Sep 22, 2025
2428b29
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 22, 2025
6f9940a
Adding transport version change
jonathan-buttner Sep 22, 2025
a8087ac
[CI] Auto commit changes from spotless
Sep 22, 2025
d79dc7d
Cleaning up helpers
jonathan-buttner Sep 22, 2025
bd63b49
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 22, 2025
41bea77
[CI] Auto commit changes from spotless
Sep 22, 2025
881d97a
Merge branch 'main' of github.com:elastic/elasticsearch into ml-opena…
jonathan-buttner Sep 22, 2025
a930d0c
Fixing transport version
jonathan-buttner Sep 22, 2025
a29f5c4
Merge branch 'ml-openai-headers-embedding' of github.com:jonathan-but…
jonathan-buttner Sep 22, 2025
df8d8fe
Merge branch 'main' into ml-openai-headers-embedding
jonathan-buttner Sep 23, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/134960.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134960
summary: Adding custom headers support openai text embeddings
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9169000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
security_stats_endpoint,9168000
inference_api_openai_embeddings_headers,9169000
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,8 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
HEADERS,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)).setDescription(
"Custom headers to include in the requests to OpenAI."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
.setDescription("Custom headers to include in the requests to OpenAI.")
.setLabel("Custom Headers")
.setRequired(false)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.openai;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;

public abstract class OpenAiTaskSettings<T extends OpenAiTaskSettings<T>> implements TaskSettings {
private static final Settings EMPTY_SETTINGS = new Settings(null, null);

private final Settings taskSettings;

public OpenAiTaskSettings(Map<String, Object> map) {
this(fromMap(map));
}

public record Settings(@Nullable String user, @Nullable Map<String, String> headers) {}

public static Settings createSettings(String user, Map<String, String> stringHeaders) {
if (user == null && stringHeaders == null) {
return EMPTY_SETTINGS;
} else {
return new Settings(user, stringHeaders);
}
}

private static Settings fromMap(Map<String, Object> map) {
if (map.isEmpty()) {
return EMPTY_SETTINGS;
}

ValidationException validationException = new ValidationException();

String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
Map<String, Object> headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return createSettings(user, stringHeaders);
}

public OpenAiTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
this(new Settings(user, headers));
}

protected OpenAiTaskSettings(Settings taskSettings) {
this.taskSettings = Objects.requireNonNull(taskSettings);
}

public String user() {
return taskSettings.user();
}

public Map<String, String> headers() {
return taskSettings.headers();
}

@Override
public boolean isEmpty() {
return taskSettings.user() == null && (taskSettings.headers() == null || taskSettings.headers().isEmpty());
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

if (taskSettings.user() != null) {
builder.field(USER, taskSettings.user());
}

if (taskSettings.headers() != null && taskSettings.headers().isEmpty() == false) {
builder.field(HEADERS, taskSettings.headers());
}

builder.endObject();

return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OpenAiTaskSettings<?> that = (OpenAiTaskSettings<?>) o;
return Objects.equals(taskSettings, that.taskSettings);
}

@Override
public int hashCode() {
return Objects.hash(taskSettings);
}

@Override
public T updatedTaskSettings(Map<String, Object> newSettings) {
Settings updatedSettings = fromMap(new HashMap<>(newSettings));

var userToUse = updatedSettings.user() == null ? taskSettings.user() : updatedSettings.user();
var headersToUse = updatedSettings.headers() == null ? taskSettings.headers() : updatedSettings.headers();
return create(userToUse, headersToUse);
}

protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map<
return model;
}

var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
return new OpenAiChatCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings));
}

public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) {
Expand Down Expand Up @@ -73,7 +72,7 @@ public OpenAiChatCompletionModel(
taskType,
service,
OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
new OpenAiChatCompletionTaskSettings(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,100 +9,44 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;

public class OpenAiChatCompletionTaskSettings implements TaskSettings {
public class OpenAiChatCompletionTaskSettings extends OpenAiTaskSettings<OpenAiChatCompletionTaskSettings> {

public static final String NAME = "openai_completion_task_settings";

public static OpenAiChatCompletionTaskSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
var headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new OpenAiChatCompletionTaskSettings(user, stringHeaders);
public OpenAiChatCompletionTaskSettings(Map<String, Object> map) {
super(map);
}

private final String user;
@Nullable
private final Map<String, String> headers;

public OpenAiChatCompletionTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
this.user = user;
this.headers = headers;
super(user, headers);
}

public OpenAiChatCompletionTaskSettings(StreamInput in) throws IOException {
this.user = in.readOptionalString();
super(readTaskSettingsFromStream(in));
}

private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException {
var user = in.readOptionalString();

Map<String, String> headers;

if (in.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
} else {
headers = null;
}
}

@Override
public boolean isEmpty() {
return user == null && (headers == null || headers.isEmpty());
}

public static OpenAiChatCompletionTaskSettings of(
OpenAiChatCompletionTaskSettings originalSettings,
OpenAiChatCompletionRequestTaskSettings requestSettings
) {
var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
var headersToUse = requestSettings.headers() == null ? originalSettings.headers : requestSettings.headers();
return new OpenAiChatCompletionTaskSettings(userToUse, headersToUse);
}

public String user() {
return user;
}

public Map<String, String> headers() {
return headers;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

if (user != null) {
builder.field(USER, user);
}

if (headers != null && headers.isEmpty() == false) {
builder.field(HEADERS, headers);
}

builder.endObject();

return builder;
return createSettings(user, headers);
}

@Override
Expand All @@ -117,30 +61,14 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(user);
out.writeOptionalString(user());
if (out.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
out.writeOptionalMap(headers, StreamOutput::writeString, StreamOutput::writeString);
out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString);
}
}

@Override
public boolean equals(Object object) {
if (this == object) return true;
if (object == null || getClass() != object.getClass()) return false;
OpenAiChatCompletionTaskSettings that = (OpenAiChatCompletionTaskSettings) object;
return Objects.equals(user, that.user) && Objects.equals(headers, that.headers);
}

@Override
public int hashCode() {
return Objects.hash(user, headers);
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(
new HashMap<>(newSettings)
);
return of(this, updatedSettings);
protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
return new OpenAiChatCompletionTaskSettings(user, headers);
}
}
Loading