Skip to content

Conversation

jonathan-buttner
Copy link
Contributor

@jonathan-buttner jonathan-buttner commented Sep 17, 2025

This PR adds custom headers support for text embedding for openai. This is the counter part to this PR: #134504

Example request

PUT _inference/text_embedding/openai { "service": "openai", "service_settings": { "api_key": "api key", "model_id": "text-embedding-3-small" }, "task_settings": { "headers": { "OpenAI-Organization": "org-id", "x-request-id": "jon-test" } } } 
@jonathan-buttner jonathan-buttner added >enhancement :ml Machine learning Team:ML Meta label for the ML team v9.2.0 labels Sep 17, 2025
@elasticsearchmachine
Copy link
Collaborator

Hi @jonathan-buttner, I've created a changelog YAML for you.

@jonathan-buttner jonathan-buttner marked this pull request as ready for review September 18, 2025 15:20
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

String user;

public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be overcomplicating things, but it should be possible to move the readTaskSettingsFromStream() and writeTo() implementations into the base class as well, by introducing an abstract method like abstract boolean shouldReadAdditionalString(TransportVersion version) which always returns false for OpenAiChatCompletionTaskSettings and checks the transport version for OpenAiEmbeddingsTaskSettings, and another abstract method that returns the transport version in which the headers were introduced for each class and which is used to determine whether to read/write the headers. The final result would look something like this for the reading side:

 public OpenAiTaskSettings(StreamInput in) throws IOException { String user; if (shouldReadAdditionalString(in.getTransportVersion())) { var discard = in.readString(); user = in.readOptionalString(); } else { user = in.readOptionalString(); } Map<String, String> headers; if (in.getTransportVersion().onOrAfter(headersIntroducedVersion())) { headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString); } else { headers = null; } taskSettings = user == null && headers == null ? EMPTY_SETTINGS : new Settings(user, headers); } 
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. I agree that these changes would help with reducing the duplication. I think for this one I'd rather leave the transport version logic in the individual classes that need it. I think it might be a bit clearer when looking at the individual class as to what has changed between versions. If we end up making more transport version dependent changes to these two classes we can revisit pulling the logic into the base class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add the changes for the empty settings though 👍

throw validationException;
}

return new Settings(user, stringHeaders);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a method like the one below, we can ensure that any time the settings would be empty, we use EMPTY_SETTINGS instead of constructing a new, empty Settings object. This would allow us to make the isEmpty() method a simple == comparison with EMPTY_SETTINGS and also prevent some unnecessary object allocations:

 private static Settings getSettingsCheckingForEmpty(String user, Map<String, String> stringHeaders) { if (user == null && (stringHeaders == null || stringHeaders.isEmpty())) { return EMPTY_SETTINGS; } else { return new Settings(user, stringHeaders); } } 

This method could also be used in the constructor below this.

I think this might lead to a subtle inconsistency with how we treat empty maps though, since passing a null user and an empty map to the constructor currently results in a Settings object that is not equal to EMPTY_SETTINGS, but still returns true from isEmpty(). Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings? If so, then I don't think we can use EMPTY_SETTINGS when the user is null and the map is non-null but empty, but we can still use it when both are null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings?

I think we're going to have a serialization issue if we internally convert an empty map to a null map because the testing logic will try to write the empty map and ensure that it will be read as an empty map. I suppose we could only do the conversion in the fromMap() function 🤔 . I suspect that it'll be pretty rare that folks include an empty headers field. And if they did include an empty headers field in the PUT request, when we create the inference endpoint it wouldn't write headers in the toXContent(). So I think to make things easier with the serialization tests I'm going to treat a null map and an empty map as separate things.

Comment on lines 32 to 35
var randomSettings = create(randomBoolean() ? null : "username", randomBoolean() ? null : Map.of("key", "value"));
var stringRep = Strings.toString(randomSettings);

assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictly speaking, this test is not testing isEmpty() but rather whether isEmpty() and toXContent() agree. If there were bugs in both methods that just so happened to agree with each other, the test would still pass despite isEmpty() being incorrect.

Since there aren't many permutations of possible arguments to the create() method for this test, it might be worth explicitly testing all the combinations and making sure that isEmpty() returns the value expected:

 public void testIsEmpty() { var bothNull = create(null, null); assertThat(bothNull.isEmpty(), is(true)); var nullUserEmptyHeaders = create(null, Map.of()); assertThat(nullUserEmptyHeaders.isEmpty(), is(true)); var nullHeaders = create("user", null); assertThat(nullHeaders.isEmpty(), is(false)); var nullUser = create(null, Map.of("K", "v")); assertThat(nullUser.isEmpty(), is(false)); var neitherNull = create("user", Map.of("K", "v")); assertThat(neitherNull.isEmpty(), is(false)); } 

It's a shame that the parameterized testing framework we use doesn't support per-method parameterization, because that would make this a bit cleaner.

newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user());
}

if (newSettings.headers() != null && newSettings.headers().isEmpty() == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the implementation of createRandom(), the headers map is never empty, so we don't have coverage of the case where the headers in newSettings are an empty map. If I force that case by modifying the test, it fails due to expecting the updated settings to be an empty map, when they are in fact the original value from initialSettings. I'm not sure whether we expect an empty map in newSettings to overwrite the original value or not, but it would be good to explicitly test that behaviour.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I added a few tests for this and updated the createRandom to also generate empty headers.


protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);

protected abstract T create(Map<String, Object> map);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid confusion, it might be better to rename this createFromMap()

assertThat(settings.headers(), is(Map.of("key", "value")));
}

public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also have a test for the cases where the object stored at the HEADERS key in the map is an empty map, a map with only null values, and for when it's a map that doesn't contain Strings?

Comment on lines 93 to 111
public void testOf_KeepsOriginalValuesWithOverridesAreNull() {
var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));

assertThat(taskSettings.updatedTaskSettings(Map.of()), is(taskSettings));
}

public void testOf_UsesOverriddenSettings() {
var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));

assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.USER, "user2")), is(create("user2", null)));
}

public void testOf_UsesOverriddenSettings_ForHeaders() {
var user = "user";
var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, user)));

var headers = Map.of("key", "value");
assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.HEADERS, headers)), is(create(user, headers)));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no of() method on OpenAiTaskSettings, so these tests should either be renamed, or just removed, since I think they cover the same behaviour as the testUpdatedTaskSettings() test, just in a slightly different way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll rename them 👍

var stringRep = Strings.toString(randomSettings);
assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
}
public static Map<String, Object> getTaskSettingsMap(@Nullable String user) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method seems to only be called in OpenAiServiceTests, so could it be moved there and made non-static?

Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason it lives in OpenAiEmbeddingsTaskSettingsTests is because it's creating a map that is valid for the embeddings task settings. IMHO if I were writing additional tests that needed to construct a valid map for embedding task settings, I'd probably look in the embeddings task settings tests rather than in the service tests. Since it's only used in the service tests I can move it if you'd like though.

made non-static?

Just curious why you're advocating for this to be non-static since it doesn't reference any member variables?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, the explanation for it being in OpenAiEmbeddingsTaskSettingsTests makes sense, although it is a little strange that it's not actually used there, even though most of the tests in that class could definitely be using it.

For making the method non-static, that was a brain fart, I meant private, although that would also defeat the purpose of this being a helper/utility method used by other classes.

As an aside, I just noticed that in OpenAiServiceTests.testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() we call this same static method to create task settings for a completion task, which is incorrect. Right now it happens to work because both the embeddings task settings and the completion task settings have a user field, but the logic for creating a completion-specific map should belong in a completion-specific class since their implementations now differ with the addition of the headers field.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now differ with the addition of the headers field.

Yeah, with this PR the completions and embeddings logic will be the same again since they'll both have headers. I'll move the map creation into the base test class 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually there already is one there: getOpenAiTaskSettingsMap, I'll remove this one.

* <a href="https://platform.openai.com/docs/api-reference/embeddings/create">see the openai docs for more details</a>
*/
public class OpenAiEmbeddingsTaskSettings implements TaskSettings {
public class OpenAiEmbeddingsTaskSettings extends OpenAiTaskSettings<OpenAiEmbeddingsTaskSettings> implements TaskSettings {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since OpenAiTaskSettings implements TaskSettings this class doesn't need to also implement it.

assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
assertThat(httpPost.getLastHeader("key").getValue(), is("value"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the key and value Strings be extracted to variables?


public abstract class OpenAiTaskSettingsTests<T extends OpenAiTaskSettings<T>> extends AbstractBWCWireSerializationTestCase<T> {

private enum HeadersDefinition {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I like this solution.

@jonathan-buttner jonathan-buttner enabled auto-merge (squash) September 22, 2025 18:59
@jonathan-buttner jonathan-buttner merged commit 9600127 into elastic:main Sep 23, 2025
34 checks passed
szybia added a commit to szybia/elasticsearch that referenced this pull request Sep 23, 2025
…-dls * upstream/main: Bump FLEET_AGENTS_MAPPINGS_VERSION so the new mapping applies on upgrades (elastic#134957) [ML] Adding custom headers support openai text embeddings (elastic#134960) Fix systemd notify to use a shared arena (elastic#135235)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

>enhancement :ml Machine learning Team:ML Meta label for the ML team v9.2.0

3 participants