Skip to content

Commit d7d4e2c

Browse files
authored
Merge branch 'main' into docs/add-weighted-rrf-documentation
2 parents 40585fc + af7b97d commit d7d4e2c

File tree

6 files changed

+26
-41
lines changed

6 files changed

+26
-41
lines changed

server/src/main/java/org/elasticsearch/inference/telemetry/InferenceStats.java

Lines changed: 5 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -49,19 +49,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
4949
);
5050
}
5151

52-
public static Map<String, Object> modelAttributes(Model model) {
53-
var modelAttributesMap = new HashMap<String, Object>();
54-
modelAttributesMap.put("service", model.getConfigurations().getService());
55-
modelAttributesMap.put("task_type", model.getTaskType().toString());
56-
57-
if (Objects.nonNull(model.getServiceSettings().modelId())) {
58-
modelAttributesMap.put("model_id", model.getServiceSettings().modelId());
59-
}
60-
61-
return modelAttributesMap;
52+
public static Map<String, Object> serviceAttributes(Model model) {
53+
return Map.of("service", model.getConfigurations().getService(), "task_type", model.getTaskType().toString());
6254
}
6355

64-
public static Map<String, Object> modelAttributes(UnparsedModel model) {
56+
public static Map<String, Object> serviceAttributes(UnparsedModel model) {
6557
return Map.of("service", model.service(), "task_type", model.taskType().toString());
6658
}
6759

@@ -77,9 +69,9 @@ public static Map<String, Object> responseAttributes(@Nullable Throwable throwab
7769
return Map.of("error.type", throwable.getClass().getSimpleName());
7870
}
7971

80-
public static Map<String, Object> modelAndResponseAttributes(Model model, @Nullable Throwable throwable) {
72+
public static Map<String, Object> serviceAndResponseAttributes(Model model, @Nullable Throwable throwable) {
8173
var metricAttributes = new HashMap<String, Object>();
82-
metricAttributes.putAll(modelAttributes(model));
74+
metricAttributes.putAll(serviceAttributes(model));
8375
metricAttributes.putAll(responseAttributes(throwable));
8476
return metricAttributes;
8577
}

server/src/test/java/org/elasticsearch/inference/telemetry/InferenceStatsTests.java

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@
2525
import java.util.Map;
2626

2727
import static org.elasticsearch.inference.telemetry.InferenceStats.create;
28-
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
2928
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
29+
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAttributes;
3030
import static org.hamcrest.Matchers.is;
3131
import static org.hamcrest.Matchers.nullValue;
3232
import static org.mockito.ArgumentMatchers.assertArg;
@@ -41,23 +41,20 @@ public static InferenceStats mockInferenceStats() {
4141
return new InferenceStats(mock(), mock(), mock());
4242
}
4343

44-
public void testRecordWithModel() {
44+
public void testRecordWithService() {
4545
var longCounter = mock(LongCounter.class);
4646
var stats = new InferenceStats(longCounter, mock(), mock());
4747

48-
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
48+
stats.requestCount().incrementBy(1, serviceAttributes(model("service", TaskType.ANY, "modelId")));
4949

50-
verify(longCounter).incrementBy(
51-
eq(1L),
52-
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
53-
);
50+
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
5451
}
5552

5653
public void testRecordWithoutModel() {
5754
var longCounter = mock(LongCounter.class);
5855
var stats = new InferenceStats(longCounter, mock(), mock());
5956

60-
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
57+
stats.requestCount().incrementBy(1, serviceAttributes(model("service", TaskType.ANY, null)));
6158

6259
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
6360
}
@@ -72,15 +69,14 @@ public void testRecordDurationWithoutError() {
7269
var stats = new InferenceStats(mock(), histogramCounter, mock());
7370

7471
Map<String, Object> metricAttributes = new HashMap<>();
75-
metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId")));
72+
metricAttributes.putAll(serviceAttributes(model("service", TaskType.ANY, "modelId")));
7673
metricAttributes.putAll(responseAttributes(null));
7774

7875
stats.inferenceDuration().record(expectedLong, metricAttributes);
7976

8077
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
8178
assertThat(attributes.get("service"), is("service"));
8279
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
83-
assertThat(attributes.get("model_id"), is("modelId"));
8480
assertThat(attributes.get("status_code"), is(200));
8581
assertThat(attributes.get("error.type"), nullValue());
8682
}));
@@ -100,15 +96,14 @@ public void testRecordDurationWithElasticsearchStatusException() {
10096
var expectedError = String.valueOf(statusCode.getStatus());
10197

10298
Map<String, Object> metricAttributes = new HashMap<>();
103-
metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId")));
99+
metricAttributes.putAll(serviceAttributes(model("service", TaskType.ANY, "modelId")));
104100
metricAttributes.putAll(responseAttributes(exception));
105101

106102
stats.inferenceDuration().record(expectedLong, metricAttributes);
107103

108104
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
109105
assertThat(attributes.get("service"), is("service"));
110106
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
111-
assertThat(attributes.get("model_id"), is("modelId"));
112107
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
113108
assertThat(attributes.get("error.type"), is(expectedError));
114109
}));
@@ -127,15 +122,14 @@ public void testRecordDurationWithOtherException() {
127122
var expectedError = exception.getClass().getSimpleName();
128123

129124
Map<String, Object> metricAttributes = new HashMap<>();
130-
metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId")));
125+
metricAttributes.putAll(serviceAttributes(model("service", TaskType.ANY, "modelId")));
131126
metricAttributes.putAll(responseAttributes(exception));
132127

133128
stats.inferenceDuration().record(expectedLong, metricAttributes);
134129

135130
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
136131
assertThat(attributes.get("service"), is("service"));
137132
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
138-
assertThat(attributes.get("model_id"), is("modelId"));
139133
assertThat(attributes.get("status_code"), nullValue());
140134
assertThat(attributes.get("error.type"), is(expectedError));
141135
}));
@@ -152,7 +146,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException()
152146
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
153147

154148
Map<String, Object> metricAttributes = new HashMap<>();
155-
metricAttributes.putAll(modelAttributes(unparsedModel));
149+
metricAttributes.putAll(serviceAttributes(unparsedModel));
156150
metricAttributes.putAll(responseAttributes(exception));
157151

158152
stats.inferenceDuration().record(expectedLong, metricAttributes);
@@ -176,7 +170,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() {
176170
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
177171

178172
Map<String, Object> metricAttributes = new HashMap<>();
179-
metricAttributes.putAll(modelAttributes(unparsedModel));
173+
metricAttributes.putAll(serviceAttributes(unparsedModel));
180174
metricAttributes.putAll(responseAttributes(exception));
181175

182176
stats.inferenceDuration().record(expectedLong, metricAttributes);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,8 @@
4949

5050
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
5151
import static org.elasticsearch.core.Strings.format;
52-
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;
53-
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAttributes;
5452
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
53+
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
5554
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5655

5756
/**
@@ -181,7 +180,7 @@ private static void validationHelper(Supplier<Boolean> validationFailure, Suppli
181180

182181
private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
183182
Map<String, Object> metricAttributes = new HashMap<>();
184-
metricAttributes.putAll(modelAttributes(model));
183+
metricAttributes.putAll(InferenceStats.serviceAttributes(model));
185184
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
186185

187186
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
@@ -270,7 +269,7 @@ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
270269

271270
private void recordRequestCountMetrics(Model model, Request request, String localNodeId) {
272271
Map<String, Object> requestCountAttributes = new HashMap<>();
273-
requestCountAttributes.putAll(modelAttributes(model));
272+
requestCountAttributes.putAll(InferenceStats.serviceAttributes(model));
274273

275274
inferenceStats.requestCount().incrementBy(1, requestCountAttributes);
276275
}
@@ -283,7 +282,7 @@ private void recordRequestDurationMetrics(
283282
@Nullable Throwable t
284283
) {
285284
Map<String, Object> metricAttributes = new HashMap<>();
286-
metricAttributes.putAll(modelAndResponseAttributes(model, unwrapCause(t)));
285+
metricAttributes.putAll(serviceAndResponseAttributes(model, unwrapCause(t)));
287286

288287
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
289288
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@
7777
import java.util.Map;
7878
import java.util.stream.Collectors;
7979

80-
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;
80+
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
8181
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
8282
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
8383
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
@@ -461,7 +461,7 @@ public void onFailure(Exception exc) {
461461

462462
private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) {
463463
Map<String, Object> requestCountAttributes = new HashMap<>();
464-
requestCountAttributes.putAll(modelAndResponseAttributes(model, throwable));
464+
requestCountAttributes.putAll(serviceAndResponseAttributes(model, throwable));
465465
requestCountAttributes.put("inference_source", "semantic_text_bulk");
466466
inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes);
467467
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@
4747

4848
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
4949
import static org.elasticsearch.core.Strings.format;
50-
import static org.elasticsearch.inference.telemetry.InferenceStats.modelAndResponseAttributes;
50+
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
5151
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
5252
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
5353

@@ -126,7 +126,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
126126
});
127127
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
128128
subscribableListener.addListener(ActionListener.wrap(started -> {
129-
inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, null));
129+
inferenceStats.deploymentDuration().record(timer.elapsedMillis(), serviceAndResponseAttributes(model, null));
130130
finalListener.onResponse(started);
131131
}, e -> {
132132
if (e instanceof ElasticsearchTimeoutException) {
@@ -139,10 +139,11 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
139139
model.getInferenceEntityId()
140140
)
141141
);
142-
inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, timeoutException));
142+
inferenceStats.deploymentDuration()
143+
.record(timer.elapsedMillis(), serviceAndResponseAttributes(model, timeoutException));
143144
finalListener.onFailure(timeoutException);
144145
} else {
145-
inferenceStats.deploymentDuration().record(timer.elapsedMillis(), modelAndResponseAttributes(model, unwrapCause(e)));
146+
inferenceStats.deploymentDuration().record(timer.elapsedMillis(), serviceAndResponseAttributes(model, unwrapCause(e)));
146147
finalListener.onFailure(e);
147148
}
148149
}));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -356,7 +356,6 @@ public void testItemFailures() throws Exception {
356356
assertThat(statusCode, is(200));
357357
}
358358
assertThat(attributes.get("task_type"), is(model.getTaskType().toString()));
359-
assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId()));
360359
assertThat(attributes.get("service"), is(model.getConfigurations().getService()));
361360
assertThat(attributes.get("inference_source"), is("semantic_text_bulk"));
362361
}));

0 commit comments

Comments
 (0)