Skip to content

Commit a0b6919

Browse files
shawn-yang-google authored and copybara-github committed
chore: Support the query method to handle types.ChatResponse response in Llama index template.
PiperOrigin-RevId: 741312199
1 parent 4c0293d commit a0b6919

File tree

3 files changed

+136
-12
lines changed

3 files changed

+136
-12
lines changed

tests/unit/vertex_llama_index/test_reasoning_engine_templates_llama_index.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,20 @@
1313
# limitations under the License.
1414
#
1515
import importlib
16-
from unittest import mock
1716
import json
17+
from unittest import mock
1818

1919
from google import auth
2020
import vertexai
2121
from google.cloud.aiplatform import initializer
2222
from vertexai.preview.reasoning_engines.templates import llama_index
2323
from vertexai.reasoning_engines import _utils
24-
import pytest
2524

2625
from llama_index.core import prompts
2726
from llama_index.core.base.llms import types
2827

28+
import pytest
29+
2930
_TEST_LOCATION = "us-central1"
3031
_TEST_PROJECT = "test-project"
3132
_TEST_MODEL = "gemini-1.0-pro"
@@ -232,3 +233,88 @@ def test_enable_tracing_warning(self, caplog, llama_index_instrumentor_none_mock
232233
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
233234
# agent.set_up()
234235
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
236+
237+
238+
class TestToJsonSerializableLlamaIndexObject:
    """Unit tests for `_utils.to_json_serializable_llama_index_object`."""

    def test_llama_index_response(self):
        # A Response carrying a text answer, two pydantic source nodes, and
        # a metadata dict is flattened into a single JSON-friendly dict.
        response = mock.Mock(spec=_utils.LlamaIndexResponse)
        response.response = "test response"
        response.source_nodes = [
            mock.Mock(
                spec=_utils.LlamaIndexBaseModel,
                model_dump_json=lambda: '{"name": "model1"}',
            ),
            mock.Mock(
                spec=_utils.LlamaIndexBaseModel,
                model_dump_json=lambda: '{"name": "model2"}',
            ),
        ]
        response.metadata = {"key": "value"}

        expected = {
            "response": "test response",
            "source_nodes": ['{"name": "model1"}', '{"name": "model2"}'],
            "metadata": {"key": "value"},
        }
        got = _utils.to_json_serializable_llama_index_object(response)
        assert got == expected

    def test_llama_index_chat_response(self):
        # A ChatResponse is serialized via its `.message` pydantic model.
        chat_response = mock.Mock(spec=_utils.LlamaIndexChatResponse)
        chat_response.message = mock.Mock(
            spec=_utils.LlamaIndexBaseModel,
            model_dump_json=lambda: '{"content": "chat message"}',
        )

        got = _utils.to_json_serializable_llama_index_object(chat_response)
        assert got == {"content": "chat message"}

    def test_llama_index_base_model(self):
        # A bare pydantic BaseModel is round-tripped through its JSON dump.
        base_model = mock.Mock(spec=_utils.LlamaIndexBaseModel)
        base_model.model_dump_json = lambda: '{"name": "test_model"}'

        got = _utils.to_json_serializable_llama_index_object(base_model)
        assert got == {"name": "test_model"}

    def test_sequence_of_llama_index_base_model(self):
        # Each element of a sequence of BaseModels is converted to a dict.
        first = mock.Mock(spec=_utils.LlamaIndexBaseModel)
        first.model_dump_json = lambda: '{"name": "test_model1"}'
        second = mock.Mock(spec=_utils.LlamaIndexBaseModel)
        second.model_dump_json = lambda: '{"name": "test_model2"}'

        got = _utils.to_json_serializable_llama_index_object([first, second])
        assert got == [{"name": "test_model1"}, {"name": "test_model2"}]

    def test_sequence_of_mixed_types(self):
        # Non-BaseModel elements in a sequence fall back to `str(...)`.
        base_model = mock.Mock(spec=_utils.LlamaIndexBaseModel)
        base_model.model_dump_json = lambda: '{"name": "test_model"}'

        got = _utils.to_json_serializable_llama_index_object(
            [base_model, "test_string"]
        )
        assert got == [{"name": "test_model"}, "test_string"]

    def test_other_type(self):
        # Anything unrecognized (here a dict) falls back to `str(...)`.
        got = _utils.to_json_serializable_llama_index_object(
            {"name": "test_model"}
        )
        assert got == "{'name': 'test_model'}"

vertexai/preview/reasoning_engines/templates/llama_index.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ def query(
523523
self,
524524
input: Union[str, Mapping[str, Any]],
525525
**kwargs: Any,
526-
) -> Union[Dict[str, Any], Sequence[Dict[str, Any]]]:
526+
) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
527527
"""Queries the Agent with the given input and config.
528528
529529
Args:
@@ -536,19 +536,14 @@ def query(
536536
Returns:
537537
The output of querying the Agent with the given input and config.
538538
"""
539-
import json
540539
from vertexai.reasoning_engines import _utils
541-
from llama_index.core.base.response import schema
542540

543541
if isinstance(input, str):
544542
input = {"input": input}
545543

546544
if not self._runnable:
547545
self.set_up()
548546

549-
response = self._runnable.run(**input, **kwargs)
550-
if isinstance(response, schema.Response):
551-
return _utils.llama_index_response_to_dict(response)
552-
if isinstance(response, Sequence):
553-
return [json.loads(r.model_dump_json()) for r in response]
554-
return json.loads(response.model_dump_json())
547+
return _utils.to_json_serializable_llama_index_object(
548+
self._runnable.run(**input, **kwargs)
549+
)

vertexai/reasoning_engines/_utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@
3939

4040
try:
4141
from llama_index.core.base.response import schema as llama_index_schema
42+
from llama_index.core.base.llms import types as llama_index_types
4243

4344
LlamaIndexResponse = llama_index_schema.Response
45+
LlamaIndexBaseModel = llama_index_schema.BaseModel
46+
LlamaIndexChatResponse = llama_index_types.ChatResponse
4447
except ImportError:
4548
LlamaIndexResponse = Any
49+
LlamaIndexBaseModel = Any
50+
LlamaIndexChatResponse = Any
4651

4752
JsonDict = Dict[str, Any]
4853

@@ -111,7 +116,7 @@ def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict:
111116
return json.loads(json.dumps(dataclasses.asdict(obj)))
112117

113118

114-
def llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
119+
def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
115120
response = {}
116121
if hasattr(obj, "response"):
117122
response["response"] = obj.response
@@ -123,6 +128,44 @@ def llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
123128
return json.loads(json.dumps(response))
124129

125130

131+
def _llama_index_chat_response_to_dict(
    obj: LlamaIndexChatResponse,
) -> Dict[str, Any]:
    """Serializes a ChatResponse's `.message` model into a plain dict."""
    message_json = obj.message.model_dump_json()
    return json.loads(message_json)
135+
136+
137+
def _llama_index_base_model_to_dict(
    obj: LlamaIndexBaseModel,
) -> Dict[str, Any]:
    """Round-trips a pydantic BaseModel through its JSON dump into a dict."""
    dumped = obj.model_dump_json()
    return json.loads(dumped)
141+
142+
143+
def to_json_serializable_llama_index_object(
    obj: Union[
        LlamaIndexResponse,
        LlamaIndexBaseModel,
        LlamaIndexChatResponse,
        Sequence[LlamaIndexBaseModel],
    ]
) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
    """Converts a LlamaIndex object to a JSON-serializable object.

    Args:
        obj: A LlamaIndex `Response`, `ChatResponse`, pydantic `BaseModel`,
            a sequence of such objects (mixed with other values), a plain
            string, or any other value.

    Returns:
        A dict for a single response/model, a list (each element converted
        independently, non-models stringified) for a sequence, the string
        itself for a plain `str`, or `str(obj)` for any other type.
    """
    # NOTE(review): if the llama_index import in this module failed, these
    # sentinel names are `typing.Any` and `isinstance` would raise TypeError;
    # presumably this is only called with llama_index installed — confirm.
    if isinstance(obj, LlamaIndexResponse):
        return _llama_index_response_to_dict(obj)
    if isinstance(obj, LlamaIndexChatResponse):
        return _llama_index_chat_response_to_dict(obj)
    # Bug fix: `str` is itself a Sequence, so without this guard a plain
    # string result (allowed by the declared return type) would be exploded
    # into a list of single-character strings. Return it unchanged instead.
    if isinstance(obj, str):
        return obj
    if isinstance(obj, Sequence):
        seq_result = []
        for item in obj:
            if isinstance(item, LlamaIndexBaseModel):
                seq_result.append(_llama_index_base_model_to_dict(item))
            else:
                # Best-effort fallback for unrecognized sequence elements.
                seq_result.append(str(item))
        return seq_result
    if isinstance(obj, LlamaIndexBaseModel):
        return _llama_index_base_model_to_dict(obj)
    # Last-resort fallback keeps the result JSON-serializable.
    return str(obj)
167+
168+
126169
def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
127170
"""Converts the contents of the httpbody message to JSON format.
128171

0 commit comments

Comments (0)