Skip to content

Commit 0935a40

Browse files
google-genai-bot authored and copybara-github committed
feat: Add Bigquery Forecast tool
This tool answers questions about structured data in BigQuery using natural language. PiperOrigin-RevId: 805414952
1 parent 3b428ec commit 0935a40

File tree

5 files changed

+276
-1
lines changed

5 files changed

+276
-1
lines changed

contributing/samples/bigquery/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
3535
the official [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview)
3636
for instructions.
3737

38+
1. `forecast`
39+
40+
Perform time series forecasting using BigQuery's `AI.FORECAST` function,
41+
leveraging the TimesFM 2.0 model.
42+
3843
## How to use
3944

4045
Set up environment variables in your `.env` file for using

src/google/adk/tools/bigquery/bigquery_toolset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ async def get_tools(
8181
metadata_tool.list_dataset_ids,
8282
metadata_tool.list_table_ids,
8383
query_tool.get_execute_sql(self._tool_settings),
84+
query_tool.forecast,
8485
data_insights_tool.ask_data_insights,
8586
]
8687
]

src/google/adk/tools/bigquery/query_tool.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import types
2020
from typing import Callable
21+
from typing import Optional
2122

2223
from google.auth.credentials import Credentials
2324
from google.cloud import bigquery
@@ -596,3 +597,169 @@ def get_execute_sql(settings: BigQueryToolConfig) -> Callable[..., dict]:
596597
execute_sql_wrapper.__doc__ = _execute_sql_write_mode.__doc__
597598

598599
return execute_sql_wrapper
600+
601+
602+
def forecast(
    project_id: str,
    history_data: str,
    timestamp_col: str,
    data_col: str,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    tool_context: ToolContext,
    horizon: int = 10,
    id_cols: Optional[list[str]] = None,
) -> dict:
  """Run a BigQuery AI time series forecast using AI.FORECAST.

  Renders an `AI.FORECAST` query over the given history data and delegates
  execution to `execute_sql`.

  Args:
    project_id (str): The GCP project id in which the query should be
      executed.
    history_data (str): The table id of the BigQuery table containing the
      history time series data, or a query statement (starting with SELECT
      or WITH) that selects the history data.
    timestamp_col (str): The name of the column containing the timestamp for
      each data point.
    data_col (str): The name of the column containing the numerical values to
      be forecasted.
    credentials (Credentials): The credentials to use for the request.
    settings (BigQueryToolConfig): The settings for the tool.
    tool_context (ToolContext): The context for the tool.
    horizon (int, optional): The number of time steps to forecast into the
      future. Defaults to 10.
    id_cols (list, optional): The column names of the id columns to indicate
      each time series when there are multiple time series in the table. All
      elements must be strings. Defaults to None.

  Returns:
    dict: Dictionary representing the result of the forecast. On success it
      contains the forecasted values along with prediction intervals; on
      failure it contains "status": "ERROR" and "error_details".

  Examples:
    Forecast daily sales for the next 7 days from a BigQuery table:

    >>> forecast(
    ...     project_id="my-gcp-project",
    ...     history_data="my-dataset.my-sales-table",
    ...     timestamp_col="sale_date",
    ...     data_col="daily_sales",
    ...     horizon=7,
    ... )
    {
      "status": "SUCCESS",
      "rows": [
        {
          "forecast_timestamp": "2025-01-08T00:00:00",
          "forecast_value": 12345.67,
          "confidence_level": 0.95,
          "prediction_interval_lower_bound": 11000.0,
          "prediction_interval_upper_bound": 13691.34,
          "ai_forecast_status": ""
        },
        ...
      ]
    }

    When an element in `id_cols` is not a string:

    >>> forecast(
    ...     project_id="my-gcp-project",
    ...     history_data="my-dataset.my-sales-table",
    ...     timestamp_col="sale_date",
    ...     data_col="daily_sales",
    ...     id_cols=["store_id", 123],
    ... )
    {
      "status": "ERROR",
      "error_details": "All elements in id_cols must be strings."
    }
  """
  # The model and confidence level are fixed by this tool; AI.FORECAST is
  # backed by the TimesFM 2.0 model.
  model = "TimesFM 2.0"
  confidence_level = 0.95

  # Decide whether `history_data` is an inline query or a table id: a query
  # is wrapped in parentheses, a table id is referenced via TABLE `...`.
  trimmed_upper_history_data = history_data.strip().upper()
  if trimmed_upper_history_data.startswith(("SELECT", "WITH")):
    history_data_source = f"({history_data})"
  else:
    history_data_source = f"TABLE `{history_data}`"

  # SECURITY NOTE(review): column names, the table id / query text, and
  # id_cols are interpolated directly into the SQL string without escaping.
  # These values ultimately come from the model/user, so a crafted value can
  # alter the query; consider validating identifiers here. Execution-time
  # restrictions are deferred to execute_sql.
  if id_cols:
    if not all(isinstance(item, str) for item in id_cols):
      return {
          "status": "ERROR",
          "error_details": "All elements in id_cols must be strings.",
      }
    id_cols_str = "[" + ", ".join([f"'{col}'" for col in id_cols]) + "]"

    query = f"""
    SELECT * FROM AI.FORECAST(
        {history_data_source},
        data_col => '{data_col}',
        timestamp_col => '{timestamp_col}',
        model => '{model}',
        id_cols => {id_cols_str},
        horizon => {horizon},
        confidence_level => {confidence_level}
    )
    """
  else:
    query = f"""
    SELECT * FROM AI.FORECAST(
        {history_data_source},
        data_col => '{data_col}',
        timestamp_col => '{timestamp_col}',
        model => '{model}',
        horizon => {horizon},
        confidence_level => {confidence_level}
    )
    """
  # Delegate execution (and credential/settings handling) to execute_sql.
  return execute_sql(project_id, query, credentials, settings, tool_context)

tests/unittests/tools/bigquery/test_bigquery_query_tool.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from google.adk.tools.bigquery.config import BigQueryToolConfig
3030
from google.adk.tools.bigquery.config import WriteMode
3131
from google.adk.tools.bigquery.query_tool import execute_sql
32+
from google.adk.tools.bigquery.query_tool import forecast
3233
from google.adk.tools.tool_context import ToolContext
3334
from google.auth.exceptions import DefaultCredentialsError
3435
from google.cloud import bigquery
@@ -1028,3 +1029,103 @@ def test_execute_sql_unexpected_project_id():
10281029
f" {compute_project_id}."
10291030
),
10301031
}
1032+
1033+
1034+
# AI.Forecast calls execute_sql with a specific query statement. We need to
1035+
# test that the query is properly constructed and call execute_sql with the
1036+
# correct parameters exactly once.
1037+
# forecast() renders an AI.FORECAST statement and hands it to execute_sql.
# Verify the rendered SQL and the pass-through arguments for the table-id
# input form, and that execute_sql is invoked exactly once.
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
def test_forecast_with_table_id(mock_execute_sql):
  creds = mock.MagicMock(spec=Credentials)
  config = BigQueryToolConfig()
  ctx = mock.create_autospec(ToolContext, instance=True)

  forecast(
      project_id="test-project",
      history_data="test-dataset.test-table",
      timestamp_col="ts_col",
      data_col="data_col",
      credentials=creds,
      settings=config,
      tool_context=ctx,
      horizon=20,
      id_cols=["id1", "id2"],
  )

  expected_query = """
    SELECT * FROM AI.FORECAST(
        TABLE `test-dataset.test-table`,
        data_col => 'data_col',
        timestamp_col => 'ts_col',
        model => 'TimesFM 2.0',
        id_cols => ['id1', 'id2'],
        horizon => 20,
        confidence_level => 0.95
    )
    """
  mock_execute_sql.assert_called_once_with(
      "test-project",
      expected_query,
      creds,
      config,
      ctx,
  )
1073+
1074+
1075+
# AI.Forecast calls execute_sql with a specific query statement. We need to
1076+
# test that the query is properly constructed and call execute_sql with the
1077+
# correct parameters exactly once.
1078+
# forecast() renders an AI.FORECAST statement and hands it to execute_sql.
# Verify the rendered SQL for the inline-query input form (history data given
# as a SELECT statement, default horizon, no id columns).
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
def test_forecast_with_query_statement(mock_execute_sql):
  creds = mock.MagicMock(spec=Credentials)
  config = BigQueryToolConfig()
  ctx = mock.create_autospec(ToolContext, instance=True)

  history_data_query = "SELECT * FROM `test-dataset.test-table`"
  forecast(
      project_id="test-project",
      history_data=history_data_query,
      timestamp_col="ts_col",
      data_col="data_col",
      credentials=creds,
      settings=config,
      tool_context=ctx,
  )

  expected_query = f"""
    SELECT * FROM AI.FORECAST(
        ({history_data_query}),
        data_col => 'data_col',
        timestamp_col => 'ts_col',
        model => 'TimesFM 2.0',
        horizon => 10,
        confidence_level => 0.95
    )
    """
  mock_execute_sql.assert_called_once_with(
      "test-project",
      expected_query,
      creds,
      config,
      ctx,
  )
1112+
1113+
1114+
def test_forecast_with_invalid_id_cols():
  """A non-string entry in id_cols must short-circuit into an ERROR result."""
  creds = mock.MagicMock(spec=Credentials)
  config = BigQueryToolConfig()
  ctx = mock.create_autospec(ToolContext, instance=True)

  result = forecast(
      project_id="test-project",
      history_data="test-dataset.test-table",
      timestamp_col="ts_col",
      data_col="data_col",
      credentials=creds,
      settings=config,
      tool_context=ctx,
      id_cols=["id1", 123],
  )

  # The validation error is returned as a payload, not raised.
  assert result["status"] == "ERROR"
  assert "All elements in id_cols must be strings." in result["error_details"]

tests/unittests/tools/bigquery/test_bigquery_toolset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default():
4141
tools = await toolset.get_tools()
4242
assert tools is not None
4343

44-
assert len(tools) == 6
44+
assert len(tools) == 7
4545
assert all([isinstance(tool, GoogleTool) for tool in tools])
4646

4747
expected_tool_names = set([
@@ -51,6 +51,7 @@ async def test_bigquery_toolset_tools_default():
5151
"get_table_info",
5252
"execute_sql",
5353
"ask_data_insights",
54+
"forecast",
5455
])
5556
actual_tool_names = set([tool.name for tool in tools])
5657
assert actual_tool_names == expected_tool_names

0 commit comments

Comments
 (0)