Skip to content

Commit 3e6a8f7

Browse files
author
Helin Wang
committed
Add support for feature importance.
Previously feature importance is added with a params parameter, the user has to set params = {"feature_importance": "true"}. This PR simplifies the logic, the user just have to pass feature_importance = True.
1 parent 972e5b4 commit 3e6a8f7

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

automl/google/cloud/automl_v1beta1/tables/tables_client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2596,7 +2596,7 @@ def predict(
25962596
model=None,
25972597
model_name=None,
25982598
model_display_name=None,
2599-
params=None,
2599+
feature_importance=False,
26002600
project=None,
26012601
region=None,
26022602
**kwargs
@@ -2643,9 +2643,9 @@ def predict(
26432643
The `model` instance you want to predict with . This must be
26442644
supplied if `model_display_name` or `model_name` are not
26452645
supplied.
2646-
params (dict[str, str]):
2647-
`feature_importance` can be set as True to enable local
2648-
explainability. The default is false.
2646+
feature_importance (bool):
2647+
True if enable feature importance explainability. The default is
2648+
False.
26492649
26502650
Returns:
26512651
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
@@ -2687,6 +2687,10 @@ def predict(
26872687

26882688
request = {"row": {"values": values}}
26892689

2690+
params = None
2691+
if feature_importance:
2692+
params = {"feature_importance": "true"}
2693+
26902694
return self.prediction_client.predict(model.name, request, params, **kwargs)
26912695

26922696
def batch_predict(

automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,25 @@ def test_predict_from_dict(self):
11371137
None,
11381138
)
11391139

1140+
def test_predict_from_dict_with_feature_importance(self):
1141+
data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
1142+
column_spec_a = mock.Mock(display_name="a", data_type=data_type)
1143+
column_spec_b = mock.Mock(display_name="b", data_type=data_type)
1144+
model_metadata = mock.Mock(
1145+
input_feature_column_specs=[column_spec_a, column_spec_b]
1146+
)
1147+
model = mock.Mock()
1148+
model.configure_mock(tables_model_metadata=model_metadata, name="my_model")
1149+
client = self.tables_client({"get_model.return_value": model}, {})
1150+
client.predict(
1151+
{"a": "1", "b": "2"}, model_name="my_model", feature_importance=True
1152+
)
1153+
client.prediction_client.predict.assert_called_with(
1154+
"my_model",
1155+
{"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}},
1156+
{"feature_importance": "true"},
1157+
)
1158+
11401159
def test_predict_from_dict_missing(self):
11411160
data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
11421161
column_spec_a = mock.Mock(display_name="a", data_type=data_type)

0 commit comments

Comments
 (0)