Skip to content

Commit 837c8ea

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client(sessions): Add label to Sessions
PiperOrigin-RevId: 835253647
1 parent bc26160 commit 837c8ea

File tree

3 files changed

+104
-71
lines changed

3 files changed

+104
-71
lines changed

tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,63 +22,69 @@
2222

2323
def test_create_session_with_ttl(client):
2424
agent_engine = client.agent_engines.create()
25-
assert isinstance(agent_engine, types.AgentEngine)
26-
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
25+
try:
26+
assert isinstance(agent_engine, types.AgentEngine)
27+
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
2728

28-
operation = client.agent_engines.create_session(
29-
name=agent_engine.api_resource.name,
30-
user_id="test-user-123",
31-
config=types.CreateAgentEngineSessionConfig(
32-
display_name="my_session",
33-
session_state={"foo": "bar"},
34-
ttl="120s",
35-
),
36-
)
37-
assert isinstance(operation, types.AgentEngineSessionOperation)
38-
assert operation.response.display_name == "my_session"
39-
assert operation.response.session_state == {"foo": "bar"}
40-
assert operation.response.user_id == "test-user-123"
41-
assert operation.response.name.startswith(agent_engine.api_resource.name)
42-
# Expire time is calculated by the server, so we only check that it is
43-
# within a reasonable range to avoid flakiness.
44-
assert (
45-
operation.response.create_time + datetime.timedelta(seconds=119.5)
46-
<= operation.response.expire_time
47-
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
48-
)
49-
# Clean up resources.
50-
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
29+
operation = client.agent_engines.create_session(
30+
name=agent_engine.api_resource.name,
31+
user_id="test-user-123",
32+
config=types.CreateAgentEngineSessionConfig(
33+
display_name="my_session",
34+
session_state={"foo": "bar"},
35+
ttl="120s",
36+
labels={"label_key": "label_value"},
37+
),
38+
)
39+
assert isinstance(operation, types.AgentEngineSessionOperation)
40+
assert operation.response.display_name == "my_session"
41+
assert operation.response.session_state == {"foo": "bar"}
42+
assert operation.response.user_id == "test-user-123"
43+
assert operation.response.labels == {"label_key": "label_value"}
44+
assert operation.response.name.startswith(agent_engine.api_resource.name)
45+
# Expire time is calculated by the server, so we only check that it is
46+
# within a reasonable range to avoid flakiness.
47+
assert (
48+
operation.response.create_time + datetime.timedelta(seconds=119.5)
49+
<= operation.response.expire_time
50+
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
51+
)
52+
finally:
53+
# Clean up resources.
54+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
5155

5256

5357
def test_create_session_with_expire_time(client):
5458
agent_engine = client.agent_engines.create()
55-
assert isinstance(agent_engine, types.AgentEngine)
56-
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
57-
expire_time = datetime.datetime(
58-
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
59-
)
59+
try:
60+
assert isinstance(agent_engine, types.AgentEngine)
61+
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
62+
expire_time = datetime.datetime(
63+
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
64+
)
6065

61-
operation = client.agent_engines.sessions.create(
62-
name=agent_engine.api_resource.name,
63-
user_id="test-user-123",
64-
config=types.CreateAgentEngineSessionConfig(
65-
display_name="my_session",
66-
session_state={"foo": "bar"},
67-
expire_time=expire_time,
68-
),
69-
)
70-
assert isinstance(operation, types.AgentEngineSessionOperation)
71-
assert operation.response.display_name == "my_session"
72-
assert operation.response.session_state == {"foo": "bar"}
73-
assert operation.response.user_id == "test-user-123"
74-
assert operation.response.name.startswith(agent_engine.api_resource.name)
75-
assert operation.response.expire_time == expire_time
76-
# Clean up resources.
77-
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
66+
operation = client.agent_engines.sessions.create(
67+
name=agent_engine.api_resource.name,
68+
user_id="test-user-123",
69+
config=types.CreateAgentEngineSessionConfig(
70+
display_name="my_session",
71+
session_state={"foo": "bar"},
72+
expire_time=expire_time,
73+
),
74+
)
75+
assert isinstance(operation, types.AgentEngineSessionOperation)
76+
assert operation.response.display_name == "my_session"
77+
assert operation.response.session_state == {"foo": "bar"}
78+
assert operation.response.user_id == "test-user-123"
79+
assert operation.response.name.startswith(agent_engine.api_resource.name)
80+
assert operation.response.expire_time == expire_time
81+
finally:
82+
# Clean up resources.
83+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
7884

7985

8086
pytestmark = pytest_helper.setup(
8187
file=__file__,
8288
globals_for_file=globals(),
83-
test_method="agent_engines.create_session",
89+
test_method="agent_engines.sessions.create",
8490
)

vertexai/_genai/sessions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def _CreateAgentEngineSessionConfig_to_vertex(
5555
if getv(from_object, ["expire_time"]) is not None:
5656
setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"]))
5757

58+
if getv(from_object, ["labels"]) is not None:
59+
setv(parent_object, ["labels"], getv(from_object, ["labels"]))
60+
5861
return to_object
5962

6063

@@ -181,6 +184,9 @@ def _UpdateAgentEngineSessionConfig_to_vertex(
181184
if getv(from_object, ["expire_time"]) is not None:
182185
setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"]))
183186

187+
if getv(from_object, ["labels"]) is not None:
188+
setv(parent_object, ["labels"], getv(from_object, ["labels"]))
189+
184190
if getv(from_object, ["update_mask"]) is not None:
185191
setv(
186192
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])

vertexai/_genai/types/common.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8778,6 +8778,10 @@ class CreateAgentEngineSessionConfig(_common.BaseModel):
87788778
default=None,
87798779
description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""",
87808780
)
8781+
labels: Optional[dict[str, str]] = Field(
8782+
default=None,
8783+
description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
8784+
)
87818785

87828786

87838787
class CreateAgentEngineSessionConfigDict(TypedDict, total=False):
@@ -8803,6 +8807,9 @@ class CreateAgentEngineSessionConfigDict(TypedDict, total=False):
88038807
expire_time: Optional[datetime.datetime]
88048808
"""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input."""
88058809

8810+
labels: Optional[dict[str, str]]
8811+
"""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""
8812+
88068813

88078814
CreateAgentEngineSessionConfigOrDict = Union[
88088815
CreateAgentEngineSessionConfig, CreateAgentEngineSessionConfigDict
@@ -8846,32 +8853,36 @@ class _CreateAgentEngineSessionRequestParametersDict(TypedDict, total=False):
88468853
class Session(_common.BaseModel):
88478854
"""A session."""
88488855

8849-
create_time: Optional[datetime.datetime] = Field(
8850-
default=None,
8851-
description="""Output only. Timestamp when the session was created.""",
8852-
)
8853-
display_name: Optional[str] = Field(
8854-
default=None, description="""Optional. The display name of the session."""
8855-
)
88568856
expire_time: Optional[datetime.datetime] = Field(
88578857
default=None,
88588858
description="""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input.""",
88598859
)
8860+
ttl: Optional[str] = Field(
8861+
default=None, description="""Optional. Input only. The TTL for this session."""
8862+
)
88608863
name: Optional[str] = Field(
88618864
default=None,
88628865
description="""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""",
88638866
)
8864-
session_state: Optional[dict[str, Any]] = Field(
8867+
create_time: Optional[datetime.datetime] = Field(
88658868
default=None,
8866-
description="""Optional. Session specific memory which stores key conversation points.""",
8867-
)
8868-
ttl: Optional[str] = Field(
8869-
default=None, description="""Optional. Input only. The TTL for this session."""
8869+
description="""Output only. Timestamp when the session was created.""",
88708870
)
88718871
update_time: Optional[datetime.datetime] = Field(
88728872
default=None,
88738873
description="""Output only. Timestamp when the session was updated.""",
88748874
)
8875+
display_name: Optional[str] = Field(
8876+
default=None, description="""Optional. The display name of the session."""
8877+
)
8878+
labels: Optional[dict[str, str]] = Field(
8879+
default=None,
8880+
description="""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
8881+
)
8882+
session_state: Optional[dict[str, Any]] = Field(
8883+
default=None,
8884+
description="""Optional. Session specific memory which stores key conversation points.""",
8885+
)
88758886
user_id: Optional[str] = Field(
88768887
default=None,
88778888
description="""Required. Immutable. String id provided by the user""",
@@ -8881,27 +8892,30 @@ class Session(_common.BaseModel):
88818892
class SessionDict(TypedDict, total=False):
88828893
"""A session."""
88838894

8884-
create_time: Optional[datetime.datetime]
8885-
"""Output only. Timestamp when the session was created."""
8886-
8887-
display_name: Optional[str]
8888-
"""Optional. The display name of the session."""
8889-
88908895
expire_time: Optional[datetime.datetime]
88918896
"""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input."""
88928897

8898+
ttl: Optional[str]
8899+
"""Optional. Input only. The TTL for this session."""
8900+
88938901
name: Optional[str]
88948902
"""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'."""
88958903

8896-
session_state: Optional[dict[str, Any]]
8897-
"""Optional. Session specific memory which stores key conversation points."""
8898-
8899-
ttl: Optional[str]
8900-
"""Optional. Input only. The TTL for this session."""
8904+
create_time: Optional[datetime.datetime]
8905+
"""Output only. Timestamp when the session was created."""
89018906

89028907
update_time: Optional[datetime.datetime]
89038908
"""Output only. Timestamp when the session was updated."""
89048909

8910+
display_name: Optional[str]
8911+
"""Optional. The display name of the session."""
8912+
8913+
labels: Optional[dict[str, str]]
8914+
"""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""
8915+
8916+
session_state: Optional[dict[str, Any]]
8917+
"""Optional. Session specific memory which stores key conversation points."""
8918+
89058919
user_id: Optional[str]
89068920
"""Required. Immutable. String id provided by the user"""
89078921

@@ -9240,6 +9254,10 @@ class UpdateAgentEngineSessionConfig(_common.BaseModel):
92409254
default=None,
92419255
description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""",
92429256
)
9257+
labels: Optional[dict[str, str]] = Field(
9258+
default=None,
9259+
description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
9260+
)
92439261
update_mask: Optional[str] = Field(
92449262
default=None,
92459263
description="""The update mask to apply. For the `FieldMask` definition, see
@@ -9273,6 +9291,9 @@ class UpdateAgentEngineSessionConfigDict(TypedDict, total=False):
92739291
expire_time: Optional[datetime.datetime]
92749292
"""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input."""
92759293

9294+
labels: Optional[dict[str, str]]
9295+
"""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""
9296+
92769297
update_mask: Optional[str]
92779298
"""The update mask to apply. For the `FieldMask` definition, see
92789299
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""

0 commit comments

Comments
 (0)