Skip to content

Commit d3b12d5

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
feat: Reenable VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 829503268
1 parent 31be512 commit d3b12d5

File tree

2 files changed

+52
-30
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+52
-30
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,6 @@ async def _init_session(
602602
):
603603
"""Initializes the session, and returns the session id."""
604604
from google.adk.events.event import Event
605-
import random
606605

607606
session_state = None
608607
if request.authorizations:
@@ -611,14 +610,9 @@ async def _init_session(
611610
auth = _Authorization(**auth)
612611
session_state[auth_id] = auth.access_token
613612

614-
if request.session_id:
615-
session_id = request.session_id
616-
else:
617-
session_id = f"temp_session_{random.randbytes(8).hex()}"
618613
session = await session_service.create_session(
619614
app_name=self._tmpl_attrs.get("app_name"),
620615
user_id=request.user_id,
621-
session_id=session_id,
622616
state=session_state,
623617
)
624618
if not session:
@@ -636,7 +630,7 @@ async def _init_session(
636630
saved_version = await artifact_service.save_artifact(
637631
app_name=self._tmpl_attrs.get("app_name"),
638632
user_id=request.user_id,
639-
session_id=session_id,
633+
session_id=session.id,
640634
filename=artifact.file_name,
641635
artifact=version_data.data,
642636
)
@@ -1078,31 +1072,49 @@ async def streaming_agent_run_with_events(self, request_json: str):
10781072

10791073
import json
10801074
from google.genai import types
1075+
from google.genai.errors import ClientError
10811076

10821077
request = _StreamRunRequest(**json.loads(request_json))
10831078
if not self._tmpl_attrs.get("in_memory_runner"):
10841079
self.set_up()
1080+
if not self._tmpl_attrs.get("runner"):
1081+
self.set_up()
10851082
# Prepare the in-memory session.
10861083
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10871084
self.set_up()
1085+
if not self._tmpl_attrs.get("artifact_service"):
1086+
self.set_up()
10881087
if not self._tmpl_attrs.get("in_memory_session_service"):
10891088
self.set_up()
1090-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1091-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1089+
if not self._tmpl_attrs.get("session_service"):
1090+
self.set_up()
10921091
app = self._tmpl_attrs.get("app")
1092+
10931093
# Try to get the session, if it doesn't exist, create a new one.
1094-
session = None
10951094
if request.session_id:
1095+
session_service = self._tmpl_attrs.get("session_service")
1096+
artifact_service = self._tmpl_attrs.get("artifact_service")
1097+
runner = self._tmpl_attrs.get("runner")
10961098
try:
10971099
session = await session_service.get_session(
10981100
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10991101
user_id=request.user_id,
11001102
session_id=request.session_id,
11011103
)
1102-
except RuntimeError:
1103-
pass
1104-
if not session:
1105-
# Fall back to create session if the session is not found.
1104+
except ClientError:
1105+
# Fall back to create session if the session is not found.
1106+
# Specifying session_id on creation is not supported,
1107+
# so session id will be regenerated.
1108+
session = await self._init_session(
1109+
session_service=session_service,
1110+
artifact_service=artifact_service,
1111+
request=request,
1112+
)
1113+
else:
1114+
# Not providing a session ID will create a new in-memory session.
1115+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1116+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1117+
runner = self._tmpl_attrs.get("in_memory_runner")
11061118
session = await self._init_session(
11071119
session_service=session_service,
11081120
artifact_service=artifact_service,
@@ -1114,7 +1126,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
11141126
# Run the agent
11151127
message_for_agent = types.Content(**request.message)
11161128
try:
1117-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
1129+
async for event in runner.run_async(
11181130
user_id=request.user_id,
11191131
session_id=session.id,
11201132
new_message=message_for_agent,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ async def _init_session(
550550
):
551551
"""Initializes the session, and returns the session id."""
552552
from google.adk.events.event import Event
553-
import random
554553

555554
session_state = None
556555
if request.authorizations:
@@ -559,14 +558,9 @@ async def _init_session(
559558
auth = _Authorization(**auth)
560559
session_state[auth_id] = auth.access_token
561560

562-
if request.session_id:
563-
session_id = request.session_id
564-
else:
565-
session_id = f"temp_session_{random.randbytes(8).hex()}"
566561
session = await session_service.create_session(
567562
app_name=self._tmpl_attrs.get("app_name"),
568563
user_id=request.user_id,
569-
session_id=session_id,
570564
state=session_state,
571565
)
572566
if not session:
@@ -584,7 +578,7 @@ async def _init_session(
584578
saved_version = await artifact_service.save_artifact(
585579
app_name=self._tmpl_attrs.get("app_name"),
586580
user_id=request.user_id,
587-
session_id=session_id,
581+
session_id=session.id,
588582
filename=artifact.file_name,
589583
artifact=version_data.data,
590584
)
@@ -965,33 +959,49 @@ async def async_stream_query(
965959
def streaming_agent_run_with_events(self, request_json: str):
966960
import json
967961
from google.genai import types
962+
from google.genai.errors import ClientError
968963

969964
event_queue = queue.Queue(maxsize=1)
970965

971966
async def _invoke_agent_async():
972967
request = _StreamRunRequest(**json.loads(request_json))
973968
if not self._tmpl_attrs.get("in_memory_runner"):
974969
self.set_up()
970+
if not self._tmpl_attrs.get("runner"):
971+
self.set_up()
975972
# Prepare the in-memory session.
976973
if not self._tmpl_attrs.get("in_memory_artifact_service"):
977974
self.set_up()
975+
if not self._tmpl_attrs.get("artifact_service"):
976+
self.set_up()
978977
if not self._tmpl_attrs.get("in_memory_session_service"):
979978
self.set_up()
980-
session_service = self._tmpl_attrs.get("in_memory_session_service")
981-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
979+
if not self._tmpl_attrs.get("session_service"):
980+
self.set_up()
981+
982982
# Try to get the session, if it doesn't exist, create a new one.
983-
session = None
984983
if request.session_id:
984+
session_service = self._tmpl_attrs.get("session_service")
985+
artifact_service = self._tmpl_attrs.get("artifact_service")
986+
runner = self._tmpl_attrs.get("runner")
985987
try:
986988
session = await session_service.get_session(
987989
app_name=self._tmpl_attrs.get("app_name"),
988990
user_id=request.user_id,
989991
session_id=request.session_id,
990992
)
991-
except RuntimeError:
992-
pass
993-
if not session:
994-
# Fall back to create session if the session is not found.
993+
except ClientError:
994+
# Fall back to create session if the session is not found.
995+
session = await self._init_session(
996+
session_service=session_service,
997+
artifact_service=artifact_service,
998+
request=request,
999+
)
1000+
else:
1001+
# Not providing a session ID will create a new in-memory session.
1002+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1003+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1004+
runner = self._tmpl_attrs.get("in_memory_runner")
9951005
session = await self._init_session(
9961006
session_service=session_service,
9971007
artifact_service=artifact_service,
@@ -1002,7 +1012,7 @@ async def _invoke_agent_async():
10021012
# Run the agent.
10031013
message_for_agent = types.Content(**request.message)
10041014
try:
1005-
for event in self._tmpl_attrs.get("in_memory_runner").run(
1015+
for event in runner.run_async(
10061016
user_id=request.user_id,
10071017
session_id=session.id,
10081018
new_message=message_for_agent,

0 commit comments

Comments
 (0)