Skip to content

Commit 7c8c218

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
fix: revert: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 827710498
1 parent 0143c07 commit 7c8c218

File tree

2 files changed

+39
-61
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+39
-61
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ async def _init_session(
585585
):
586586
"""Initializes the session, and returns the session id."""
587587
from google.adk.events.event import Event
588+
import random
588589

589590
session_state = None
590591
if request.authorizations:
@@ -593,9 +594,14 @@ async def _init_session(
593594
auth = _Authorization(**auth)
594595
session_state[f"temp:{auth_id}"] = auth.access_token
595596

597+
if request.session_id:
598+
session_id = request.session_id
599+
else:
600+
session_id = f"temp_session_{random.randbytes(8).hex()}"
596601
session = await session_service.create_session(
597602
app_name=self._tmpl_attrs.get("app_name"),
598603
user_id=request.user_id,
604+
session_id=session_id,
599605
state=session_state,
600606
)
601607
if not session:
@@ -613,7 +619,7 @@ async def _init_session(
613619
saved_version = await artifact_service.save_artifact(
614620
app_name=self._tmpl_attrs.get("app_name"),
615621
user_id=request.user_id,
616-
session_id=session.id,
622+
session_id=session_id,
617623
filename=artifact.file_name,
618624
artifact=version_data.data,
619625
)
@@ -1052,61 +1058,43 @@ async def streaming_agent_run_with_events(self, request_json: str):
10521058

10531059
import json
10541060
from google.genai import types
1055-
from google.genai.errors import ClientError
10561061

10571062
request = _StreamRunRequest(**json.loads(request_json))
10581063
if not self._tmpl_attrs.get("in_memory_runner"):
10591064
self.set_up()
1060-
if not self._tmpl_attrs.get("runner"):
1061-
self.set_up()
10621065
# Prepare the in-memory session.
10631066
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10641067
self.set_up()
1065-
if not self._tmpl_attrs.get("artifact_service"):
1066-
self.set_up()
10671068
if not self._tmpl_attrs.get("in_memory_session_service"):
10681069
self.set_up()
1069-
if not self._tmpl_attrs.get("session_service"):
1070-
self.set_up()
1070+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1071+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
10711072
app = self._tmpl_attrs.get("app")
1072-
10731073
# Try to get the session, if it doesn't exist, create a new one.
1074+
session = None
10741075
if request.session_id:
1075-
session_service = self._tmpl_attrs.get("session_service")
1076-
artifact_service = self._tmpl_attrs.get("artifact_service")
1077-
runner = self._tmpl_attrs.get("runner")
10781076
try:
10791077
session = await session_service.get_session(
10801078
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10811079
user_id=request.user_id,
10821080
session_id=request.session_id,
10831081
)
1084-
except ClientError:
1085-
# Fall back to create session if the session is not found.
1086-
# Specifying session_id on creation is not supported,
1087-
# so session id will be regenerated.
1088-
session = await self._init_session(
1089-
session_service=session_service,
1090-
artifact_service=artifact_service,
1091-
request=request,
1092-
)
1093-
else:
1094-
# Not providing a session ID will create a new in-memory session.
1095-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1096-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1097-
runner = self._tmpl_attrs.get("in_memory_runner")
1098-
session = await session_service.create_session(
1099-
app_name=self._tmpl_attrs.get("app_name"),
1100-
user_id=request.user_id,
1101-
session_id=request.session_id,
1082+
except RuntimeError:
1083+
pass
1084+
if not session:
1085+
# Fall back to create session if the session is not found.
1086+
session = await self._init_session(
1087+
session_service=session_service,
1088+
artifact_service=artifact_service,
1089+
request=request,
11021090
)
11031091
if not session:
11041092
raise RuntimeError("Session initialization failed.")
11051093

11061094
# Run the agent
11071095
message_for_agent = types.Content(**request.message)
11081096
try:
1109-
async for event in runner.run_async(
1097+
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
11101098
user_id=request.user_id,
11111099
session_id=session.id,
11121100
new_message=message_for_agent,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ async def _init_session(
533533
):
534534
"""Initializes the session, and returns the session id."""
535535
from google.adk.events.event import Event
536+
import random
536537

537538
session_state = None
538539
if request.authorizations:
@@ -541,9 +542,14 @@ async def _init_session(
541542
auth = _Authorization(**auth)
542543
session_state[f"temp:{auth_id}"] = auth.access_token
543544

545+
if request.session_id:
546+
session_id = request.session_id
547+
else:
548+
session_id = f"temp_session_{random.randbytes(8).hex()}"
544549
session = await session_service.create_session(
545550
app_name=self._tmpl_attrs.get("app_name"),
546551
user_id=request.user_id,
552+
session_id=session_id,
547553
state=session_state,
548554
)
549555
if not session:
@@ -561,7 +567,7 @@ async def _init_session(
561567
saved_version = await artifact_service.save_artifact(
562568
app_name=self._tmpl_attrs.get("app_name"),
563569
user_id=request.user_id,
564-
session_id=session.id,
570+
session_id=session_id,
565571
filename=artifact.file_name,
566572
artifact=version_data.data,
567573
)
@@ -939,60 +945,44 @@ async def async_stream_query(
939945
def streaming_agent_run_with_events(self, request_json: str):
940946
import json
941947
from google.genai import types
942-
from google.genai.errors import ClientError
943948

944949
event_queue = queue.Queue(maxsize=1)
945950

946951
async def _invoke_agent_async():
947952
request = _StreamRunRequest(**json.loads(request_json))
948953
if not self._tmpl_attrs.get("in_memory_runner"):
949954
self.set_up()
950-
if not self._tmpl_attrs.get("runner"):
951-
self.set_up()
952955
# Prepare the in-memory session.
953956
if not self._tmpl_attrs.get("in_memory_artifact_service"):
954957
self.set_up()
955-
if not self._tmpl_attrs.get("artifact_service"):
956-
self.set_up()
957958
if not self._tmpl_attrs.get("in_memory_session_service"):
958959
self.set_up()
959-
if not self._tmpl_attrs.get("session_service"):
960-
self.set_up()
960+
session_service = self._tmpl_attrs.get("in_memory_session_service")
961+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
962+
# Try to get the session, if it doesn't exist, create a new one.
963+
session = None
961964
if request.session_id:
962-
session_service = self._tmpl_attrs.get("session_service")
963-
artifact_service = self._tmpl_attrs.get("artifact_service")
964-
runner = self._tmpl_attrs.get("runner")
965965
try:
966966
session = await session_service.get_session(
967967
app_name=self._tmpl_attrs.get("app_name"),
968968
user_id=request.user_id,
969969
session_id=request.session_id,
970970
)
971-
except ClientError:
972-
# Fall back to create session if the session is not found.
973-
# Specifying session_id on creation is not supported,
974-
# so session id will be regenerated.
975-
session = await self._init_session(
976-
session_service=session_service,
977-
artifact_service=artifact_service,
978-
request=request,
979-
)
980-
else:
981-
# Not providing a session ID will create a new in-memory session.
982-
session_service = self._tmpl_attrs.get("in_memory_session_service")
983-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
984-
runner = self._tmpl_attrs.get("in_memory_runner")
985-
session = await session_service.create_session(
986-
app_name=self._tmpl_attrs.get("app_name"),
987-
user_id=request.user_id,
988-
session_id=request.session_id,
971+
except RuntimeError:
972+
pass
973+
if not session:
974+
# Fall back to create session if the session is not found.
975+
session = await self._init_session(
976+
session_service=session_service,
977+
artifact_service=artifact_service,
978+
request=request,
989979
)
990980
if not session:
991981
raise RuntimeError("Session initialization failed.")
992982
# Run the agent.
993983
message_for_agent = types.Content(**request.message)
994984
try:
995-
for event in runner.run(
985+
for event in self._tmpl_attrs.get("in_memory_runner").run(
996986
user_id=request.user_id,
997987
session_id=session.id,
998988
new_message=message_for_agent,

0 commit comments

Comments
 (0)