@@ -1289,6 +1289,147 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error(
             "'intermediate_events' or 'response' columns"
         ) in str(excinfo.value)

+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.InMemorySessionService")
+    @mock.patch("vertexai._genai._evals_common.Runner")
+    @mock.patch("vertexai._genai._evals_common.LlmAgent")
+    def test_run_inference_with_local_agent(
+        self,
+        mock_llm_agent,
+        mock_runner,
+        mock_session_service,
+        mock_eval_dataset_loader,
+    ):
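+        # Source dataset: two prompts, each paired with its own session_inputs
+        # (user_id plus initial session state).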
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_instance = mock.Mock()
+        mock_llm_agent.return_value = mock_agent_instance
+        mock_session_service.return_value.create_session = mock.AsyncMock()
+        mock_runner_instance = mock_runner.return_value
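+        # Each simulated run yields one intermediate event followed by the
+        # final event whose text is the agent's response.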
+        stream_run_return_value_1 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "1",
+                    "content": {"parts": [{"text": "intermediate1"}]},
+                    "timestamp": 123,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "2",
+                    "content": {"parts": [{"text": "agent response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                }
+            ),
+        ]
+        stream_run_return_value_2 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "3",
+                    "content": {"parts": [{"text": "intermediate2"}]},
+                    "timestamp": 125,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "4",
+                    "content": {"parts": [{"text": "agent response 2"}]},
+                    "timestamp": 126,
+                    "author": "model",
+                }
+            ),
+        ]
+
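+        # Wraps a list in an async generator, mimicking the event stream
+        # that Runner.run_async returns.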
+        async def async_iterator(items):
+            for item in items:
+                yield item
+
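+        # Dispatch on the incoming message text so each prompt receives its
+        # matching event stream.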
+        def run_async_side_effect(*args, **kwargs):
+            new_message = kwargs.get("new_message")
+            if new_message and new_message.parts[0].text == "agent prompt":
+                return async_iterator(stream_run_return_value_1)
+            return async_iterator(stream_run_return_value_2)
+
+        mock_runner_instance.run_async.side_effect = run_async_side_effect
+
+        inference_result = self.client.evals.run_inference(
+            agent=mock_agent_instance,
+            src=mock_df,
+        )
+
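+        # One session and one Runner should be created per dataset row.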
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        assert mock_session_service.call_count == 2
+        mock_runner.assert_called_with(
+            agent=mock_agent_instance,
+            app_name="local agent run",
+            session_service=mock_session_service.return_value,
+        )
+        assert mock_runner.call_count == 2
+        assert mock_runner_instance.run_async.call_count == 2
+
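+        # Non-final events surface in intermediate_events; the last event's
+        # text becomes the row's response.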
+        expected_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+                "intermediate_events": [
+                    [
+                        {
+                            "event_id": "1",
+                            "content": {"parts": [{"text": "intermediate1"}]},
+                            "creation_timestamp": 123,
+                            "author": "model",
+                        }
+                    ],
+                    [
+                        {
+                            "event_id": "3",
+                            "content": {"parts": [{"text": "intermediate2"}]},
+                            "creation_timestamp": 125,
+                            "author": "model",
+                        }
+                    ],
+                ],
+                "response": ["agent response", "agent response 2"],
+            }
+        )
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df.sort_values(by="prompt").reset_index(
+                drop=True
+            ),
+            expected_df.sort_values(by="prompt").reset_index(drop=True),
+        )
+        assert inference_result.candidate_name is None
+        assert inference_result.gcs_source is None
+
     def test_run_inference_with_litellm_string_prompt_format(
         self,
         mock_api_client_fixture,
@@ -1641,6 +1782,7 @@ def test_run_agent_internal_success(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1671,6 +1813,7 @@ def test_run_agent_internal_error_response(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1697,6 +1840,7 @@ def test_run_agent_internal_malformed_event(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
         assert "response" in result_df.columns