@@ -533,6 +533,7 @@ async def _init_session(
533533 ):
534534 """Initializes the session, and returns the session id."""
535535 from google .adk .events .event import Event
536+ import random
536537
537538 session_state = None
538539 if request .authorizations :
@@ -541,9 +542,14 @@ async def _init_session(
541542 auth = _Authorization (** auth )
542543 session_state [f"temp:{ auth_id } " ] = auth .access_token
543544
545+ if request .session_id :
546+ session_id = request .session_id
547+ else :
548+ session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
544549 session = await session_service .create_session (
545550 app_name = self ._tmpl_attrs .get ("app_name" ),
546551 user_id = request .user_id ,
552+ session_id = session_id ,
547553 state = session_state ,
548554 )
549555 if not session :
@@ -561,7 +567,7 @@ async def _init_session(
561567 saved_version = await artifact_service .save_artifact (
562568 app_name = self ._tmpl_attrs .get ("app_name" ),
563569 user_id = request .user_id ,
564- session_id = session . id ,
570+ session_id = session_id ,
565571 filename = artifact .file_name ,
566572 artifact = version_data .data ,
567573 )
@@ -939,60 +945,44 @@ async def async_stream_query(
939945 def streaming_agent_run_with_events (self , request_json : str ):
940946 import json
941947 from google .genai import types
942- from google .genai .errors import ClientError
943948
944949 event_queue = queue .Queue (maxsize = 1 )
945950
946951 async def _invoke_agent_async ():
947952 request = _StreamRunRequest (** json .loads (request_json ))
948953 if not self ._tmpl_attrs .get ("in_memory_runner" ):
949954 self .set_up ()
950- if not self ._tmpl_attrs .get ("runner" ):
951- self .set_up ()
952955 # Prepare the in-memory session.
953956 if not self ._tmpl_attrs .get ("in_memory_artifact_service" ):
954957 self .set_up ()
955- if not self ._tmpl_attrs .get ("artifact_service" ):
956- self .set_up ()
957958 if not self ._tmpl_attrs .get ("in_memory_session_service" ):
958959 self .set_up ()
959- if not self ._tmpl_attrs .get ("session_service" ):
960- self .set_up ()
960+ session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
961+ artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
962+ # Try to get the session, if it doesn't exist, create a new one.
963+ session = None
961964 if request .session_id :
962- session_service = self ._tmpl_attrs .get ("session_service" )
963- artifact_service = self ._tmpl_attrs .get ("artifact_service" )
964- runner = self ._tmpl_attrs .get ("runner" )
965965 try :
966966 session = await session_service .get_session (
967967 app_name = self ._tmpl_attrs .get ("app_name" ),
968968 user_id = request .user_id ,
969969 session_id = request .session_id ,
970970 )
971- except ClientError :
972- # Fall back to create session if the session is not found.
973- # Specifying session_id on creation is not supported,
974- # so session id will be regenerated.
975- session = await self ._init_session (
976- session_service = session_service ,
977- artifact_service = artifact_service ,
978- request = request ,
979- )
980- else :
981- # Not providing a session ID will create a new in-memory session.
982- session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
983- artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
984- runner = self ._tmpl_attrs .get ("in_memory_runner" )
985- session = await session_service .create_session (
986- app_name = self ._tmpl_attrs .get ("app_name" ),
987- user_id = request .user_id ,
988- session_id = request .session_id ,
971+ except RuntimeError :
972+ pass
973+ if not session :
974+ # Fall back to create session if the session is not found.
975+ session = await self ._init_session (
976+ session_service = session_service ,
977+ artifact_service = artifact_service ,
978+ request = request ,
989979 )
990980 if not session :
991981 raise RuntimeError ("Session initialization failed." )
992982 # Run the agent.
993983 message_for_agent = types .Content (** request .message )
994984 try :
995- for event in runner .run (
985+ for event in self . _tmpl_attrs . get ( "in_memory_runner" ) .run (
996986 user_id = request .user_id ,
997987 session_id = session .id ,
998988 new_message = message_for_agent ,
0 commit comments