1818import datetime
1919import time
2020from typing import Dict , List , Optional , Sequence , Tuple , Union
21+ import warnings
2122
2223import abc
2324
@@ -2525,6 +2526,7 @@ def __init__(
25252526 display_name : str ,
25262527 optimization_prediction_type : str ,
25272528 optimization_objective : Optional [str ] = None ,
2529+ column_specs : Optional [Dict [str , str ]] = None ,
25282530 column_transformations : Optional [Union [Dict , List [Dict ]]] = None ,
25292531 optimization_objective_recall_value : Optional [float ] = None ,
25302532 optimization_objective_precision_value : Optional [float ] = None ,
@@ -2536,6 +2538,15 @@ def __init__(
25362538 ):
25372539 """Constructs an AutoML Tabular Training Job.
25382540
2541+ Example usage:
2542+
2543+ job = training_jobs.AutoMLTabularTrainingJob(
2544+ display_name="my_display_name",
2545+ optimization_prediction_type="classification",
2546+ optimization_objective="minimize-log-loss",
2547+ column_specs={"column_1": "auto", "column_2": "numeric"},
2548+ )
2549+
25392550 Args:
25402551 display_name (str):
25412552 Required. The user-defined name of this TrainingPipeline.
@@ -2576,15 +2587,29 @@ def __init__(
25762587 "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
25772588 "minimize-mae" - Minimize mean-absolute error (MAE).
25782589 "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
2579- column_transformations (Optional[Union[Dict, List[Dict]]]):
2590+ column_specs (Dict[str, str]):
2591+ Optional. Alternative to column_transformations where the keys of the dict
2592+ are column names and their respective values are one of
2593+ AutoMLTabularTrainingJob.column_data_types.
2594+ When creating transformation for BigQuery Struct column, the column
2595+ should be flattened using "." as the delimiter. Only columns with no child
2596+ should have a transformation.
2597+ If an input column has no transformations on it, such a column is
2598+ ignored by the training, except for the targetColumn, which should have
2599+ no transformations defined on.
2600+ Only one of column_transformations or column_specs should be passed.
2601+ column_transformations (Union[Dict, List[Dict]]):
25802602 Optional. Transformations to apply to the input columns (i.e. columns other
25812603 than the targetColumn). Each transformation may produce multiple
25822604 result values from the column's value, and all are used for training.
25832605 When creating transformation for BigQuery Struct column, the column
2584- should be flattened using "." as the delimiter.
2606+ should be flattened using "." as the delimiter. Only columns with no child
2607+ should have a transformation.
25852608 If an input column has no transformations on it, such a column is
25862609 ignored by the training, except for the targetColumn, which should have
25872610 no transformations defined on.
2611+ Only one of column_transformations or column_specs should be passed.
2612+ Consider using column_specs as column_transformations will be deprecated eventually.
25882613 optimization_objective_recall_value (float):
25892614 Optional. Required when maximize-precision-at-recall optimizationObjective was
25902615 picked, represents the recall value at which the optimization is done.
@@ -2628,6 +2653,9 @@ def __init__(
26282653 If set, the trained Model will be secured by this key.
26292654
26302655 Overrides encryption_spec_key_name set in aiplatform.init.
2656+
2657+ Raises:
2658+ ValueError: When both column_transformations and column_specs were passed.
26312659 """
26322660 super ().__init__ (
26332661 display_name = display_name ,
@@ -2637,7 +2665,26 @@ def __init__(
26372665 training_encryption_spec_key_name = training_encryption_spec_key_name ,
26382666 model_encryption_spec_key_name = model_encryption_spec_key_name ,
26392667 )
2640- self ._column_transformations = column_transformations
2668+ # user populated transformations
2669+ if column_transformations is not None and column_specs is not None :
2670+ raise ValueError (
2671+ "Both column_transformations and column_specs were passed. Only one is allowed."
2672+ )
2673+ if column_transformations is not None :
2674+ self ._column_transformations = column_transformations
2675+ warnings .simplefilter ("always" , DeprecationWarning )
2676+ warnings .warn (
2677+ "consider using column_specs instead. column_transformations will be deprecated in the future." ,
2678+ DeprecationWarning ,
2679+ stacklevel = 2 ,
2680+ )
2681+ elif column_specs is not None :
2682+ self ._column_transformations = [
2683+ {transformation : {"column_name" : column_name }}
2684+ for column_name , transformation in column_specs .items ()
2685+ ]
2686+ else :
2687+ self ._column_transformations = None
26412688 self ._optimization_objective = optimization_objective
26422689 self ._optimization_prediction_type = optimization_prediction_type
26432690 self ._optimization_objective_recall_value = optimization_objective_recall_value
@@ -2860,6 +2907,7 @@ def _run(
28602907
28612908 training_task_definition = schema .training_job .definition .automl_tabular
28622909
2910+ # auto-populate transformations
28632911 if self ._column_transformations is None :
28642912 _LOGGER .info (
28652913 "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
@@ -2870,21 +2918,19 @@ def _run(
28702918 for column_name in dataset .column_names
28712919 if column_name != target_column
28722920 ]
2873- column_transformations = [
2921+ self . _column_transformations = [
28742922 {"auto" : {"column_name" : column_name }} for column_name in column_names
28752923 ]
28762924
28772925 _LOGGER .info (
28782926 "The column transformation of type 'auto' was set for the following columns: %s."
28792927 % column_names
28802928 )
2881- else :
2882- column_transformations = self ._column_transformations
28832929
28842930 training_task_inputs_dict = {
28852931 # required inputs
28862932 "targetColumn" : target_column ,
2887- "transformations" : column_transformations ,
2933+ "transformations" : self . _column_transformations ,
28882934 "trainBudgetMilliNodeHours" : budget_milli_node_hours ,
28892935 # optional inputs
28902936 "weightColumnName" : weight_column ,
@@ -2935,6 +2981,44 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
29352981 """
29362982 self ._additional_experiments .extend (additional_experiments )
29372983
@staticmethod
def get_auto_column_specs(
    dataset: datasets.TabularDataset, target_column: str,
) -> Dict[str, str]:
    """Returns a dict with all non-target columns as keys and 'auto' as values.

    Example usage:

    column_specs = training_jobs.AutoMLTabularTrainingJob.get_auto_column_specs(
        dataset=my_dataset,
        target_column="my_target_column",
    )

    Args:
        dataset (datasets.TabularDataset):
            Required. Intended dataset.
        target_column (str):
            Required. Intended target column.
    Returns:
        Dict[str, str]
            Column names as keys and 'auto' as values
    """
    # The target column is left out; every remaining column maps to "auto".
    return {
        column: "auto" for column in dataset.column_names if column != target_column
    }
3011+
class column_data_types:
    """String names of the column data types usable as column_specs values."""

    # Each constant is the literal string that becomes the transformation key
    # when a column_specs entry is converted into a column transformation.
    AUTO = "auto"
    NUMERIC = "numeric"
    CATEGORICAL = "categorical"
    TIMESTAMP = "timestamp"
    TEXT = "text"
    REPEATED_NUMERIC = "repeated_numeric"
    REPEATED_CATEGORICAL = "repeated_categorical"
    REPEATED_TEXT = "repeated_text"
3021+
29383022
29393023class AutoMLForecastingTrainingJob (_TrainingJob ):
29403024 _supported_training_schemas = (schema .training_job .definition .automl_forecasting ,)
0 commit comments