
Commit a52198a

sararob authored and copybara-github committed

feat: GenAI SDK client(evals) - Add async evaluate_instances method

PiperOrigin-RevId: 784173496

1 parent 74310d3 · commit a52198a

File tree (4 files changed, +110 −8 lines)

tests/unit/vertexai/genai/replays/test_batch_evaluate.py
tests/unit/vertexai/genai/replays/test_evaluate_instances.py
vertexai/_genai/evals.py
vertexai/_genai/mypy.ini

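At a glance, the commit adds an asynchronous counterpart to the public evaluate_instances surface: the synchronous path goes through client.evals and the new asynchronous path through client.aio.evals, both taking a keyword-only metric_config of type types._EvaluateInstancesRequestParameters. Below is a minimal sketch of the two call shapes (not part of the commit), assuming a GenAI client constructed elsewhere (the replay tests receive it as a fixture) and reusing the BleuInput payload from the tests below; build_bleu_config, evaluate_bleu_sync, and evaluate_bleu_async are illustrative names, not SDK symbols.

from vertexai._genai import types


def build_bleu_config() -> types._EvaluateInstancesRequestParameters:
    # The same BleuInput payload the replay tests use.
    bleu_input = types.BleuInput(
        instances=[
            types.BleuInstance(
                reference="The quick brown fox jumps over the lazy dog.",
                prediction="A fast brown fox leaps over a lazy dog.",
            )
        ],
        metric_spec=types.BleuSpec(),
    )
    return types._EvaluateInstancesRequestParameters(bleu_input=bleu_input)


def evaluate_bleu_sync(client):
    # Synchronous surface, as exercised by the updated unit tests.
    return client.evals.evaluate_instances(metric_config=build_bleu_config())


async def evaluate_bleu_async(client):
    # Asynchronous surface added by this commit (client.aio.evals).
    return await client.aio.evals.evaluate_instances(
        metric_config=build_bleu_config()
    )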

tests/unit/vertexai/genai/replays/test_batch_evaluate.py

Lines changed: 23 additions & 0 deletions
@@ -14,6 +14,8 @@
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
 
+import pytest
+
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
 
@@ -43,3 +45,24 @@ def test_batch_eval(client):
     globals_for_file=globals(),
     test_method="evals.batch_evaluate",
 )
+
+pytest_plugins = ("pytest_asyncio",)
+
+
+@pytest.mark.asyncio
+async def test_batch_eval_async(client):
+    eval_dataset = types.EvaluationDataset(
+        gcs_source=types.GcsSource(
+            uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
+        )
+    )
+
+    response = await client.aio.evals.batch_evaluate(
+        dataset=eval_dataset,
+        metrics=[
+            types.PrebuiltMetric.TEXT_QUALITY,
+        ],
+        dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
+    )
+    assert "operations" in response.name
+    assert "EvaluateDatasetOperationMetadata" in response.metadata.get("@type")

tests/unit/vertexai/genai/replays/test_evaluate_instances.py

Lines changed: 47 additions & 8 deletions
@@ -14,11 +14,12 @@
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
 
+import json
 
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
 import pandas as pd
-import json
+import pytest
 
 
 def test_bleu_metric(client):
@@ -31,7 +32,11 @@ def test_bleu_metric(client):
         ],
         metric_spec=types.BleuSpec(),
     )
-    response = client.evals._evaluate_instances(bleu_input=test_bleu_input)
+    response = client.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            bleu_input=test_bleu_input
+        )
+    )
     assert len(response.bleu_results.bleu_metric_values) == 1
 
 
@@ -46,8 +51,10 @@ def test_exact_match_metric(client):
         ],
         metric_spec=types.ExactMatchSpec(),
     )
-    response = client.evals._evaluate_instances(
-        exact_match_input=test_exact_match_input
+    response = client.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            exact_match_input=test_exact_match_input
+        )
     )
     assert len(response.exact_match_results.exact_match_metric_values) == 1
 
@@ -63,7 +70,11 @@ def test_rouge_metric(client):
         ],
         metric_spec=types.RougeSpec(rouge_type="rougeL"),
     )
-    response = client.evals._evaluate_instances(rouge_input=test_rouge_input)
+    response = client.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            rouge_input=test_rouge_input
+        )
+    )
     assert len(response.rouge_results.rouge_metric_values) == 1
 
 
@@ -78,7 +89,11 @@ def test_pointwise_metric(client):
             metric_prompt_template="Evaluate if the response '{response}' correctly answers the prompt '{prompt}'."
         ),
     )
-    response = client.evals._evaluate_instances(pointwise_metric_input=test_input)
+    response = client.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            pointwise_metric_input=test_input
+        )
+    )
     assert response.pointwise_metric_result is not None
     assert response.pointwise_metric_result.score is not None
 
@@ -100,8 +115,10 @@ def test_pairwise_metric_with_autorater(client):
     )
     autorater_config = types.AutoraterConfig(sampling_count=2)
 
-    response = client.evals._evaluate_instances(
-        pairwise_metric_input=test_input, autorater_config=autorater_config
+    response = client.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            pairwise_metric_input=test_input, autorater_config=autorater_config
+        )
     )
     assert response.pairwise_metric_result is not None
     assert response.pairwise_metric_result.pairwise_choice is not None
@@ -147,3 +164,25 @@ def test_inference_with_prompt_template(client):
     globals_for_file=globals(),
     test_method="evals.evaluate",
 )
+
+
+pytest_plugins = ("pytest_asyncio",)
+
+
+@pytest.mark.asyncio
+async def test_bleu_metric_async(client):
+    test_bleu_input = types.BleuInput(
+        instances=[
+            types.BleuInstance(
+                reference="The quick brown fox jumps over the lazy dog.",
+                prediction="A fast brown fox leaps over a lazy dog.",
+            )
+        ],
+        metric_spec=types.BleuSpec(),
+    )
+    response = await client.aio.evals.evaluate_instances(
+        metric_config=types._EvaluateInstancesRequestParameters(
+            bleu_input=test_bleu_input
+        )
+    )
+    assert len(response.bleu_results.bleu_metric_values) == 1
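Beyond replaying single requests, the async surface also makes it straightforward to issue several evaluate_instances calls concurrently. A minimal sketch (not part of the commit), assuming a client constructed elsewhere and a list of pre-built types._EvaluateInstancesRequestParameters objects; evaluate_many is an illustrative helper, not an SDK symbol.

import asyncio


async def evaluate_many(client, metric_configs):
    # Fan out several requests on the async surface added in this commit and
    # return the EvaluateInstancesResponse objects in the original order.
    return await asyncio.gather(
        *(
            client.aio.evals.evaluate_instances(metric_config=cfg)
            for cfg in metric_configs
        )
    )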

vertexai/_genai/evals.py

Lines changed: 18 additions & 0 deletions
@@ -1522,3 +1522,21 @@ async def batch_evaluate(
         self._api_client._verify_response(return_value)
 
         return return_value
+
+    async def evaluate_instances(
+        self,
+        *,
+        metric_config: types._EvaluateInstancesRequestParameters,
+    ) -> types.EvaluateInstancesResponse:
+        """Evaluates an instance of a model."""
+
+        if isinstance(metric_config, types._EvaluateInstancesRequestParameters):
+            metric_config = metric_config.model_dump()
+        else:
+            metric_config = dict(metric_config)
+
+        result = await self._evaluate_instances(
+            **metric_config,
+        )
+
+        return result
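The body of the new method is a thin dispatch: a typed _EvaluateInstancesRequestParameters payload is flattened with pydantic's model_dump() (a plain mapping is accepted too) and splatted into the private request helper as keyword arguments. A generic, self-contained sketch of that pattern follows, assuming pydantic v2; RequestParams, _evaluate, and evaluate are illustrative names, not the SDK's.

import asyncio
from typing import Any, Optional

from pydantic import BaseModel


class RequestParams(BaseModel):
    # Stand-in for the typed parameters model; each field maps to one
    # keyword argument of the private helper below.
    bleu_input: Optional[dict] = None
    rouge_input: Optional[dict] = None


async def _evaluate(*, bleu_input: Any = None, rouge_input: Any = None) -> dict:
    # Stand-in for the private request method; it just echoes its arguments.
    return {"bleu_input": bleu_input, "rouge_input": rouge_input}


async def evaluate(metric_config) -> dict:
    # Accept either the typed model or a plain mapping, then splat the
    # resulting dict into the keyword-only helper.
    if isinstance(metric_config, RequestParams):
        payload = metric_config.model_dump()
    else:
        payload = dict(metric_config)
    return await _evaluate(**payload)


print(asyncio.run(evaluate(RequestParams(bleu_input={"instances": []}))))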

vertexai/_genai/mypy.ini

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+[mypy]
+# TODO(b/422425982): Fix arg-type errors
+disable_error_code = import-not-found, import-untyped, arg-type
+
+# We only want to run mypy on _genai dir, ignore dependent modules
+[mypy-vertexai.agent_engines.*]
+ignore_errors = True
+
+[mypy-vertexai.preview.*]
+ignore_errors = True
+
+[mypy-vertexai.generative_models.*]
+ignore_errors = True
+
+[mypy-vertexai.prompts.*]
+ignore_errors = True
+
+[mypy-vertexai.tuning.*]
+ignore_errors = True
+
+[mypy-vertexai.caching.*]
+ignore_errors = True
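The new mypy.ini scopes type checking to the _genai package and silences errors from the surrounding vertexai modules. A sketch of exercising it programmatically, assuming mypy is installed; invoking it this way (rather than through the repository's own tooling) is an assumption for illustration.

# Run mypy over the _genai package with the new config file.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "vertexai/_genai/mypy.ini", "vertexai/_genai"]
)
print(stdout or stderr)
print("exit status:", exit_status)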
