Commit 30e41d0

vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - Add support to local agent run for agent eval
PiperOrigin-RevId: 841231423
1 parent 3eb38bf commit 30e41d0
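
A minimal usage sketch of the new local-agent inference path, inferred from the unit test added below. The import paths, client construction, and agent arguments (agent name, model) are illustrative assumptions for this sketch, not confirmed documentation; only the run_inference(agent=..., src=...) call shape and the resulting eval_dataset_df columns are taken from the test.

# Sketch only: assumes vertexai.Client and google.adk.agents.LlmAgent are available.
import pandas as pd
import vertexai
from google.adk.agents import LlmAgent  # assumed ADK agent type

client = vertexai.Client(project="my-project", location="us-central1")

# A locally constructed agent; no Agent Engine deployment is required.
agent = LlmAgent(name="my_agent", model="gemini-2.0-flash")

# Prompts plus optional per-row session inputs (user_id and initial state),
# matching the columns exercised by the test.
src = pd.DataFrame(
    {
        "prompt": ["agent prompt", "agent prompt 2"],
        "session_inputs": [
            {"user_id": "123", "state": {"a": "1"}},
            {"user_id": "456", "state": {"b": "2"}},
        ],
    }
)

# Runs the agent locally for each prompt; per the test's assertions, the
# returned dataset gains "intermediate_events" and "response" columns.
result = client.evals.run_inference(agent=agent, src=src)
print(result.eval_dataset_df[["prompt", "response"]])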

File tree

5 files changed (+372, -57 lines)


tests/unit/vertexai/genai/test_evals.py

Lines changed: 144 additions & 0 deletions
@@ -1289,6 +1289,147 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error(
             "'intermediate_events' or 'response' columns"
         ) in str(excinfo.value)
 
+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.InMemorySessionService")
+    @mock.patch("vertexai._genai._evals_common.Runner")
+    @mock.patch("vertexai._genai._evals_common.LlmAgent")
+    def test_run_inference_with_local_agent(
+        self,
+        mock_llm_agent,
+        mock_runner,
+        mock_session_service,
+        mock_eval_dataset_loader,
+    ):
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_instance = mock.Mock()
+        mock_llm_agent.return_value = mock_agent_instance
+        mock_session_service.return_value.create_session = mock.AsyncMock()
+        mock_runner_instance = mock_runner.return_value
+        stream_run_return_value_1 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "1",
+                    "content": {"parts": [{"text": "intermediate1"}]},
+                    "timestamp": 123,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "2",
+                    "content": {"parts": [{"text": "agent response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                }
+            ),
+        ]
+        stream_run_return_value_2 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "3",
+                    "content": {"parts": [{"text": "intermediate2"}]},
+                    "timestamp": 125,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "4",
+                    "content": {"parts": [{"text": "agent response 2"}]},
+                    "timestamp": 126,
+                    "author": "model",
+                }
+            ),
+        ]
+
+        async def async_iterator(items):
+            for item in items:
+                yield item
+
+        def run_async_side_effect(*args, **kwargs):
+            new_message = kwargs.get("new_message")
+            if new_message and new_message.parts[0].text == "agent prompt":
+                return async_iterator(stream_run_return_value_1)
+            return async_iterator(stream_run_return_value_2)
+
+        mock_runner_instance.run_async.side_effect = run_async_side_effect
+
+        inference_result = self.client.evals.run_inference(
+            agent=mock_agent_instance,
+            src=mock_df,
+        )
+
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        assert mock_session_service.call_count == 2
+        mock_runner.assert_called_with(
+            agent=mock_agent_instance,
+            app_name="local agent run",
+            session_service=mock_session_service.return_value,
+        )
+        assert mock_runner.call_count == 2
+        assert mock_runner_instance.run_async.call_count == 2
+
+        expected_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+                "intermediate_events": [
+                    [
+                        {
+                            "event_id": "1",
+                            "content": {"parts": [{"text": "intermediate1"}]},
+                            "creation_timestamp": 123,
+                            "author": "model",
+                        }
+                    ],
+                    [
+                        {
+                            "event_id": "3",
+                            "content": {"parts": [{"text": "intermediate2"}]},
+                            "creation_timestamp": 125,
+                            "author": "model",
+                        }
+                    ],
+                ],
+                "response": ["agent response", "agent response 2"],
+            }
+        )
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df.sort_values(by="prompt").reset_index(
+                drop=True
+            ),
+            expected_df.sort_values(by="prompt").reset_index(drop=True),
+        )
+        assert inference_result.candidate_name is None
+        assert inference_result.gcs_source is None
+
     def test_run_inference_with_litellm_string_prompt_format(
         self,
         mock_api_client_fixture,
@@ -1641,6 +1782,7 @@ def test_run_agent_internal_success(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
 
@@ -1671,6 +1813,7 @@ def test_run_agent_internal_error_response(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
 
@@ -1697,6 +1840,7 @@ def test_run_agent_internal_malformed_event(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
         assert "response" in result_df.columns
