|
18 | 18 | import json |
19 | 19 | import types |
20 | 20 | from typing import Callable |
| 21 | +from typing import Optional |
21 | 22 |
|
22 | 23 | from google.auth.credentials import Credentials |
23 | 24 | from google.cloud import bigquery |
@@ -596,3 +597,169 @@ def get_execute_sql(settings: BigQueryToolConfig) -> Callable[..., dict]: |
596 | 597 | execute_sql_wrapper.__doc__ = _execute_sql_write_mode.__doc__ |
597 | 598 |
|
598 | 599 | return execute_sql_wrapper |
| 600 | + |
| 601 | + |
| 602 | +def forecast( |
| 603 | + project_id: str, |
| 604 | + history_data: str, |
| 605 | + timestamp_col: str, |
| 606 | + data_col: str, |
| 607 | + credentials: Credentials, |
| 608 | + settings: BigQueryToolConfig, |
| 609 | + tool_context: ToolContext, |
| 610 | + horizon: int = 10, |
| 611 | + id_cols: Optional[list[str]] = None, |
| 612 | +) -> dict: |
| 613 | + """Run a BigQuery AI time series forecast using AI.FORECAST. |
| 614 | +
|
| 615 | + Args: |
| 616 | + project_id (str): The GCP project id in which the query should be |
| 617 | + executed. |
| 618 | + history_data (str): The table id of the BigQuery table containing the |
| 619 | + history time series data or a query statement that select the history |
| 620 | + data. |
| 621 | + timestamp_col (str): The name of the column containing the timestamp for |
| 622 | + each data point. |
| 623 | + data_col (str): The name of the column containing the numerical values to |
| 624 | + be forecasted. |
| 625 | + credentials (Credentials): The credentials to use for the request. |
| 626 | + settings (BigQueryToolConfig): The settings for the tool. |
| 627 | + tool_context (ToolContext): The context for the tool. |
| 628 | + horizon (int, optional): The number of time steps to forecast into the |
| 629 | + future. Defaults to 10. |
| 630 | + id_cols (list, optional): The column names of the id columns to indicate |
| 631 | + each time series when there are multiple time series in the table. All |
| 632 | + elements must be strings. Defaults to None. |
| 633 | +
|
| 634 | + Returns: |
| 635 | + dict: Dictionary representing the result of the forecast. The result |
| 636 | + contains the forecasted values along with prediction intervals. |
| 637 | +
|
| 638 | + Examples: |
| 639 | + Forecast daily sales for the next 7 days based on historical data from |
| 640 | + a BigQuery table: |
| 641 | +
|
| 642 | + >>> forecast( |
| 643 | + ... project_id="my-gcp-project", |
| 644 | + ... history_data="my-dataset.my-sales-table", |
| 645 | + ... timestamp_col="sale_date", |
| 646 | + ... data_col="daily_sales", |
| 647 | + ... horizon=7 |
| 648 | + ... ) |
| 649 | + { |
| 650 | + "status": "SUCCESS", |
| 651 | + "rows": [ |
| 652 | + { |
| 653 | + "forecast_timestamp": "2025-01-08T00:00:00", |
| 654 | + "forecast_value": 12345.67, |
| 655 | + "confidence_level": 0.95, |
| 656 | + "prediction_interval_lower_bound": 11000.0, |
| 657 | + "prediction_interval_upper_bound": 13691.34, |
| 658 | + "ai_forecast_status": "" |
| 659 | + }, |
| 660 | + ... |
| 661 | + ] |
| 662 | + } |
| 663 | +
|
| 664 | + Forecast multiple time series using a SQL query as input: |
| 665 | +
|
| 666 | + >>> history_query = ( |
| 667 | + ... "SELECT unique_id, timestamp, value " |
| 668 | + ... "FROM `my-project.my-dataset.my-timeseries-table` " |
| 669 | + ... "WHERE timestamp > '1980-01-01'" |
| 670 | + ... ) |
| 671 | + >>> forecast( |
| 672 | + ... project_id="my-gcp-project", |
| 673 | + ... history_data=history_query, |
| 674 | + ... timestamp_col="timestamp", |
| 675 | + ... data_col="value", |
| 676 | + ... id_cols=["unique_id"], |
| 677 | + ... horizon=14 |
| 678 | + ... ) |
| 679 | + { |
| 680 | + "status": "SUCCESS", |
| 681 | + "rows": [ |
| 682 | + { |
| 683 | + "unique_id": "T1", |
| 684 | + "forecast_timestamp": "1980-08-28T00:00:00", |
| 685 | + "forecast_value": 1253218.75, |
| 686 | + "confidence_level": 0.95, |
| 687 | + "prediction_interval_lower_bound": 274252.51, |
| 688 | + "prediction_interval_upper_bound": 2232184.99, |
| 689 | + "ai_forecast_status": "" |
| 690 | + }, |
| 691 | + ... |
| 692 | + ] |
| 693 | + } |
| 694 | +
|
| 695 | + Error Scenarios: |
| 696 | + When an element in `id_cols` is not a string: |
| 697 | +
|
| 698 | + >>> forecast( |
| 699 | + ... project_id="my-gcp-project", |
| 700 | + ... history_data="my-dataset.my-sales-table", |
| 701 | + ... timestamp_col="sale_date", |
| 702 | + ... data_col="daily_sales", |
| 703 | + ... id_cols=["store_id", 123] |
| 704 | + ... ) |
| 705 | + { |
| 706 | + "status": "ERROR", |
| 707 | + "error_details": "All elements in id_cols must be strings." |
| 708 | + } |
| 709 | +
|
| 710 | + When `history_data` refers to a table that does not exist: |
| 711 | +
|
| 712 | + >>> forecast( |
| 713 | + ... project_id="my-gcp-project", |
| 714 | + ... history_data="my-dataset.non-existent-table", |
| 715 | + ... timestamp_col="sale_date", |
| 716 | + ... data_col="daily_sales" |
| 717 | + ... ) |
| 718 | + { |
| 719 | + "status": "ERROR", |
| 720 | + "error_details": "Not found: Table |
| 721 | + my-gcp-project:my-dataset.non-existent-table was not found in |
| 722 | + location US" |
| 723 | + } |
| 724 | + """ |
| 725 | + model = "TimesFM 2.0" |
| 726 | + confidence_level = 0.95 |
| 727 | + trimmed_upper_history_data = history_data.strip().upper() |
| 728 | + if trimmed_upper_history_data.startswith( |
| 729 | + "SELECT" |
| 730 | + ) or trimmed_upper_history_data.startswith("WITH"): |
| 731 | + history_data_source = f"({history_data})" |
| 732 | + else: |
| 733 | + history_data_source = f"TABLE `{history_data}`" |
| 734 | + |
| 735 | + if id_cols: |
| 736 | + if not all(isinstance(item, str) for item in id_cols): |
| 737 | + return { |
| 738 | + "status": "ERROR", |
| 739 | + "error_details": "All elements in id_cols must be strings.", |
| 740 | + } |
| 741 | + id_cols_str = "[" + ", ".join([f"'{col}'" for col in id_cols]) + "]" |
| 742 | + |
| 743 | + query = f""" |
| 744 | + SELECT * FROM AI.FORECAST( |
| 745 | + {history_data_source}, |
| 746 | + data_col => '{data_col}', |
| 747 | + timestamp_col => '{timestamp_col}', |
| 748 | + model => '{model}', |
| 749 | + id_cols => {id_cols_str}, |
| 750 | + horizon => {horizon}, |
| 751 | + confidence_level => {confidence_level} |
| 752 | + ) |
| 753 | + """ |
| 754 | + else: |
| 755 | + query = f""" |
| 756 | + SELECT * FROM AI.FORECAST( |
| 757 | + {history_data_source}, |
| 758 | + data_col => '{data_col}', |
| 759 | + timestamp_col => '{timestamp_col}', |
| 760 | + model => '{model}', |
| 761 | + horizon => {horizon}, |
| 762 | + confidence_level => {confidence_level} |
| 763 | + ) |
| 764 | + """ |
| 765 | + return execute_sql(project_id, query, credentials, settings, tool_context) |
0 commit comments