[ML] Adding custom headers support openai text embeddings #134960
Hi @jonathan-buttner, I've created a changelog YAML for you.
…i-headers-embedding
…tner/elasticsearch into ml-openai-headers-embedding
Pinging @elastic/ml-core (Team:ML)
String user;

public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
    if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
It might be overcomplicating things, but it should be possible to move the readTaskSettingsFromStream() and writeTo() implementations into the base class as well, by introducing an abstract method like abstract boolean shouldReadAdditionalString(TransportVersion version), which always returns false for OpenAiChatCompletionTaskSettings and checks the transport version for OpenAiEmbeddingsTaskSettings, and another abstract method that returns the transport version in which the headers were introduced for each class and which is used to determine whether to read/write the headers. The final result would look something like this for the reading side:

public OpenAiTaskSettings(StreamInput in) throws IOException {
    String user;
    if (shouldReadAdditionalString(in.getTransportVersion())) {
        var discard = in.readString();
        user = in.readOptionalString();
    } else {
        user = in.readOptionalString();
    }

    Map<String, String> headers;
    if (in.getTransportVersion().onOrAfter(headersIntroducedVersion())) {
        headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
    } else {
        headers = null;
    }

    taskSettings = user == null && headers == null ? EMPTY_SETTINGS : new Settings(user, headers);
}
Thanks for the suggestion. I agree that these changes would help with reducing the duplication. I think for this one I'd rather leave the transport version logic in the individual classes that need it. I think it might be a bit clearer when looking at the individual class as to what has changed between versions. If we end up making more transport version dependent changes to these two classes we can revisit pulling the logic into the base class.
I'll add the changes for the empty settings though 👍
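For readers outside the codebase, the version-gating idea discussed in this thread can be sketched with plain java.io streams. This is only an illustration under stated assumptions: the class name, the integer version numbers, and the HEADERS_INTRODUCED constant are hypothetical, and the streams stand in for Elasticsearch's StreamInput, StreamOutput, and TransportVersion. It is not the code in this PR.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in: plain java.io streams and an int version replace
// Elasticsearch's StreamInput, StreamOutput, and TransportVersion.
final class VersionGatedSettingsSketch {
    static final int HEADERS_INTRODUCED = 2; // assumed version number, for illustration only

    final String user;                 // nullable
    final Map<String, String> headers; // nullable

    VersionGatedSettingsSketch(String user, Map<String, String> headers) {
        this.user = user;
        this.headers = headers;
    }

    void writeTo(DataOutputStream out, int peerVersion) throws IOException {
        out.writeBoolean(user != null);
        if (user != null) {
            out.writeUTF(user);
        }
        // Older peers do not know about headers, so nothing is written for them.
        if (peerVersion >= HEADERS_INTRODUCED) {
            out.writeBoolean(headers != null);
            if (headers != null) {
                out.writeInt(headers.size());
                for (Map.Entry<String, String> entry : headers.entrySet()) {
                    out.writeUTF(entry.getKey());
                    out.writeUTF(entry.getValue());
                }
            }
        }
    }

    static VersionGatedSettingsSketch readFrom(DataInputStream in, int peerVersion) throws IOException {
        String user = in.readBoolean() ? in.readUTF() : null;
        Map<String, String> headers = null;
        // Mirror of writeTo: only read headers when the sender wrote them.
        if (peerVersion >= HEADERS_INTRODUCED && in.readBoolean()) {
            int size = in.readInt();
            headers = new LinkedHashMap<>();
            for (int i = 0; i < size; i++) {
                String key = in.readUTF();
                headers.put(key, in.readUTF());
            }
        }
        return new VersionGatedSettingsSketch(user, headers);
    }

    // Serializes and deserializes at the given version, as a BWC test would.
    static VersionGatedSettingsSketch roundTrip(VersionGatedSettingsSketch settings, int version) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            settings.writeTo(new DataOutputStream(bytes), version);
            return readFrom(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())), version);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Keeping the read and write checks side by side in one class, as preferred in the reply above, makes it easy to see that both sides gate on the same version.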
    throw validationException;
}

return new Settings(user, stringHeaders);
Using a method like the one below, we can ensure that any time the settings would be empty, we use EMPTY_SETTINGS instead of constructing a new, empty Settings object. This would allow us to make the isEmpty() method a simple == comparison with EMPTY_SETTINGS and also prevent some unnecessary object allocations:

private static Settings getSettingsCheckingForEmpty(String user, Map<String, String> stringHeaders) {
    if (user == null && (stringHeaders == null || stringHeaders.isEmpty())) {
        return EMPTY_SETTINGS;
    } else {
        return new Settings(user, stringHeaders);
    }
}
This method could also be used in the constructor below this.
I think this might lead to a subtle inconsistency with how we treat empty maps though, since passing a null user and an empty map to the constructor currently results in a Settings object that is not equal to EMPTY_SETTINGS, but still returns true from isEmpty(). Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings? If so, then I don't think we can use EMPTY_SETTINGS when the user is null and the map is non-null but empty, but we can still use it when both are null.
Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings?
I think we're going to have a serialization issue if we internally convert an empty map to a null map because the testing logic will try to write the empty map and ensure that it will be read as an empty map. I suppose we could only do the conversion in the fromMap() function 🤔. I suspect that it'll be pretty rare that folks include an empty headers field. And if they did include an empty headers field in the PUT request, when we create the inference endpoint it wouldn't write headers in the toXContent(). So I think to make things easier with the serialization tests I'm going to treat a null map and an empty map as separate things.
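The distinction settled on here can be illustrated with a small self-contained sketch. The record and factory below are hypothetical stand-ins, not the PR's classes: the shared EMPTY_SETTINGS instance is reused only when both fields are null, so a null-user, empty-map instance still reports isEmpty() as true but is not equal to EMPTY_SETTINGS and keeps its empty map.

```java
import java.util.Map;

final class EmptySettingsSketch {
    // Hypothetical stand-in for the Settings record discussed in this thread.
    record Settings(String user, Map<String, String> headers) {
        boolean isEmpty() {
            return user == null && (headers == null || headers.isEmpty());
        }
    }

    static final Settings EMPTY_SETTINGS = new Settings(null, null);

    // Reuse the shared instance only when both fields are null, so that an
    // empty map written to the wire is still read back as an empty map.
    static Settings of(String user, Map<String, String> headers) {
        return user == null && headers == null ? EMPTY_SETTINGS : new Settings(user, headers);
    }
}
```

With this shape, isEmpty() cannot be reduced to an identity check against EMPTY_SETTINGS, which is the cost of keeping the two cases separate.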
var randomSettings = create(randomBoolean() ? null : "username", randomBoolean() ? null : Map.of("key", "value"));
var stringRep = Strings.toString(randomSettings);

assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
Strictly speaking, this test is not testing isEmpty() but rather whether isEmpty() and toXContent() agree. If there were bugs in both methods that just so happened to agree with each other, the test would still pass despite isEmpty() being incorrect.

Since there aren't many permutations of possible arguments to the create() method for this test, it might be worth explicitly testing all the combinations and making sure that isEmpty() returns the value expected:

public void testIsEmpty() {
    var bothNull = create(null, null);
    assertThat(bothNull.isEmpty(), is(true));

    var nullUserEmptyHeaders = create(null, Map.of());
    assertThat(nullUserEmptyHeaders.isEmpty(), is(true));

    var nullHeaders = create("user", null);
    assertThat(nullHeaders.isEmpty(), is(false));

    var nullUser = create(null, Map.of("K", "v"));
    assertThat(nullUser.isEmpty(), is(false));

    var neitherNull = create("user", Map.of("K", "v"));
    assertThat(neitherNull.isEmpty(), is(false));
}
It's a shame that the parameterized testing framework we use doesn't support per-method parameterization, because that would make this a bit cleaner.
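In the absence of per-method parameterization, one possible alternative is a small table of cases iterated inside a single test method. The sketch below is an assumption-laden illustration: the Settings and Case records are hypothetical stand-ins, and a plain loop replaces the test framework's assertions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class IsEmptyTableSketch {
    // Hypothetical stand-in for the task settings under test.
    record Settings(String user, Map<String, String> headers) {
        boolean isEmpty() {
            return user == null && (headers == null || headers.isEmpty());
        }
    }

    // One row per input combination, paired with the expected isEmpty() result.
    record Case(String user, Map<String, String> headers, boolean expectedEmpty) {}

    static final List<Case> CASES = List.of(
        new Case(null, null, true),
        new Case(null, Map.of(), true),
        new Case("user", null, false),
        new Case(null, Map.of("K", "v"), false),
        new Case("user", Map.of("K", "v"), false)
    );

    // Returns a description of every row whose actual result differs from the expected one.
    static List<String> failures() {
        List<String> failures = new ArrayList<>();
        for (Case c : CASES) {
            boolean actual = new Settings(c.user(), c.headers()).isEmpty();
            if (actual != c.expectedEmpty()) {
                failures.add(c.toString());
            }
        }
        return failures;
    }
}
```

Collecting failures instead of stopping at the first one reports every mismatching row in a single run.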
    newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user());
}

if (newSettings.headers() != null && newSettings.headers().isEmpty() == false) {
Due to the implementation of createRandom(), the headers map is never empty, so we don't have coverage of the case where the headers in newSettings are an empty map. If I force that case by modifying the test, it fails due to expecting the updated settings to be an empty map, when they are in fact the original value from initialSettings. I'm not sure whether we expect an empty map in newSettings to overwrite the original value or not, but it would be good to explicitly test that behaviour.
Good point, I added a few tests for this and updated the createRandom to also generate empty headers.
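The behaviour the reviewer observed, where an empty headers map in the new settings leaves the original headers in place, can be sketched as follows. This is an assumption about the merge semantics based only on the comment above, not the PR's implementation; the record and the updatedWith method are hypothetical names.

```java
import java.util.Map;

final class UpdateSketch {
    // Hypothetical stand-in for the task settings and their update logic.
    record Settings(String user, Map<String, String> headers) {
        // A field is overridden only when the override carries a usable value;
        // a null user or a null/empty headers map keeps the original.
        Settings updatedWith(Settings overrides) {
            String newUser = overrides.user() != null ? overrides.user() : user;
            boolean hasHeaders = overrides.headers() != null && overrides.headers().isEmpty() == false;
            Map<String, String> newHeaders = hasHeaders ? overrides.headers() : headers;
            return new Settings(newUser, newHeaders);
        }
    }
}
```

Under these semantics an update request cannot clear existing headers, which is the kind of behaviour worth pinning down with an explicit test.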
protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);

protected abstract T create(Map<String, Object> map);
To avoid confusion, it might be better to rename this to createFromMap().
    assertThat(settings.headers(), is(Map.of("key", "value")));
}

public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() {
Could we also have a test for the cases where the object stored at the HEADERS key in the map is an empty map, a map with only null values, and for when it's a map that doesn't contain Strings?
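The three inputs named in this comment can be made concrete with a small sketch. The parser below is hypothetical and only one plausible policy (an absent field stays null, an empty map stays empty, and null or non-String entries are rejected); how the PR actually treats each case is not shown on this page.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class HeadersParsingSketch {
    // Hypothetical parser for the raw object stored under the headers key.
    // Returns null when the field is absent, otherwise a validated copy.
    static Map<String, String> parseHeaders(Object raw) {
        if (raw == null) {
            return null; // keeps "no headers" distinct from "empty headers"
        }
        if (raw instanceof Map == false) {
            throw new IllegalArgumentException("[headers] must be a map");
        }
        Map<String, String> parsed = new LinkedHashMap<>();
        for (Map.Entry<?, ?> entry : ((Map<?, ?>) raw).entrySet()) {
            // instanceof is false for null, so null values are rejected here too.
            if (entry.getKey() instanceof String == false || entry.getValue() instanceof String == false) {
                throw new IllegalArgumentException("[headers] must map strings to strings, found: " + entry);
            }
            parsed.put((String) entry.getKey(), (String) entry.getValue());
        }
        return parsed;
    }
}
```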
public void testOf_KeepsOriginalValuesWithOverridesAreNull() {
    var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));

    assertThat(taskSettings.updatedTaskSettings(Map.of()), is(taskSettings));
}

public void testOf_UsesOverriddenSettings() {
    var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));

    assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.USER, "user2")), is(create("user2", null)));
}

public void testOf_UsesOverriddenSettings_ForHeaders() {
    var user = "user";
    var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, user)));

    var headers = Map.of("key", "value");
    assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.HEADERS, headers)), is(create(user, headers)));
}
There is no of() method on OpenAiTaskSettings, so these tests should either be renamed, or just removed, since I think they cover the same behaviour as the testUpdatedTaskSettings() test, just in a slightly different way.
I'll rename them 👍
    var stringRep = Strings.toString(randomSettings);
    assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
}
public static Map<String, Object> getTaskSettingsMap(@Nullable String user) {
This method seems to only be called in OpenAiServiceTests, so could it be moved there and made non-static?
I think the reason it lives in OpenAiEmbeddingsTaskSettingsTests is because it's creating a map that is valid for the embeddings task settings. IMHO if I were writing additional tests that needed to construct a valid map for embedding task settings, I'd probably look in the embeddings task settings tests rather than in the service tests. Since it's only used in the service tests I can move it if you'd like though.
made non-static?
Just curious why you're advocating for this to be non-static since it doesn't reference any member variables?
I see, the explanation for it being in OpenAiEmbeddingsTaskSettingsTests makes sense, although it is a little strange that it's not actually used there, even though most of the tests in that class could definitely be using it.

For making the method non-static, that was a brain fart, I meant private, although that would also defeat the purpose of this being a helper/utility method used by other classes.

As an aside, I just noticed that in OpenAiServiceTests.testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() we call this same static method to create task settings for a completion task, which is incorrect. Right now it happens to work because both the embeddings task settings and the completion task settings have a user field, but the logic for creating a completion-specific map should belong in a completion-specific class since their implementations now differ with the addition of the headers field.
now differ with the addition of the headers field.
Yeah, with this PR the completions and embeddings logic will be the same again since they'll both have headers. I'll move the map creation into the base test class 👍
Actually there already is one there: getOpenAiTaskSettingsMap, I'll remove this one.
 * <a href="https://platform.openai.com/docs/api-reference/embeddings/create">see the openai docs for more details</a>
 */
public class OpenAiEmbeddingsTaskSettings implements TaskSettings {
public class OpenAiEmbeddingsTaskSettings extends OpenAiTaskSettings<OpenAiEmbeddingsTaskSettings> implements TaskSettings {
Since OpenAiTaskSettings implements TaskSettings, this class doesn't need to also implement it.
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
assertThat(httpPost.getLastHeader("key").getValue(), is("value"));
Could the "key" and "value" Strings be extracted to variables?
public abstract class OpenAiTaskSettingsTests<T extends OpenAiTaskSettings<T>> extends AbstractBWCWireSerializationTestCase<T> {

    private enum HeadersDefinition {
Nice, I like this solution.
…-dls * upstream/main: Bump FLEET_AGENTS_MAPPINGS_VERSION so the new mapping applies on upgrades (elastic#134957) [ML] Adding custom headers support openai text embeddings (elastic#134960) Fix systemd notify to use a shared arena (elastic#135235)
This PR adds custom headers support for OpenAI text embeddings. It is the counterpart to PR #134504.
Example request