Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
77e4bd8
fix: added proto message conversion to MDMJob.update fields
rosiezou Oct 5, 2022
cca9b0d
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 5, 2022
2ff747a
addressed PR comment
rosiezou Oct 6, 2022
d53f7fc
formatting
rosiezou Oct 6, 2022
07a2c3d
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 6, 2022
0cdfc9b
Merge branch 'mm-fix' of https://github.com/googleapis/python-aiplatf…
gcf-owl-bot[bot] Oct 6, 2022
7348828
Merge branch 'main' into mm-fix
rosiezou Oct 7, 2022
f8c16cc
replaced string literal with constant
rosiezou Oct 7, 2022
d029939
Merge branch 'main' into mm-fix
rosiezou Oct 7, 2022
fb60c39
adding _gca_resource re-assignmnet to mdm job class
rosiezou Oct 11, 2022
cc72e72
Merge branch 'main' into mm-fix
rosiezou Oct 17, 2022
3ebd054
Added side effects in get_mdm_job pytest mock
rosiezou Oct 17, 2022
bdc1bce
fixing side effects
rosiezou Oct 18, 2022
650166e
formatting
rosiezou Oct 18, 2022
fcc0638
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 18, 2022
250734d
Merge branch 'mm-fix' of https://github.com/googleapis/python-aiplatf…
gcf-owl-bot[bot] Oct 18, 2022
6dc5f12
minor edits to variable names
rosiezou Oct 18, 2022
d6d300c
Addressed PR feedback
rosiezou Oct 18, 2022
0b87db7
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Oct 18, 2022
000428b
addressed more PR commentes
rosiezou Oct 19, 2022
ccbfbf1
addressed PR comments
rosiezou Oct 19, 2022
7d52bd8
Merge branch 'main' into mm-fix
rosiezou Oct 19, 2022
de5489d
fix linter errors
rosiezou Oct 19, 2022
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2427,7 +2427,8 @@ def update(
are allowed. See https://goo.gl/xmQnxf for more information
and examples of labels.
bigquery_tables_log_ttl (int):
Optional. The TTL(time to live) of BigQuery tables in user projects
Optional. The number of days for which the logs are stored.
The TTL(time to live) of BigQuery tables in user projects
which stores logs. A day is the basic unit of
the TTL and we take the ceil of TTL/86400(a
day). e.g. { second: 3600} indicates ttl = 1
Expand Down Expand Up @@ -2456,25 +2457,30 @@ def update(
current_job = self.api_client.get_model_deployment_monitoring_job(
name=self._gca_resource.name
)
mdm_job_name = self._gca_resource.name
update_mask: List[str] = []
if display_name is not None:
update_mask.append("display_name")
current_job.display_name = display_name
if schedule_config is not None:
update_mask.append("model_deployment_monitoring_schedule_config")
current_job.model_deployment_monitoring_schedule_config = schedule_config
current_job.model_deployment_monitoring_schedule_config = (
schedule_config.as_proto()
)
if alert_config is not None:
update_mask.append("model_monitoring_alert_config")
current_job.model_monitoring_alert_config = alert_config
current_job.model_monitoring_alert_config = alert_config.as_proto()
if logging_sampling_strategy is not None:
update_mask.append("logging_sampling_strategy")
current_job.logging_sampling_strategy = logging_sampling_strategy
current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto()
if labels is not None:
update_mask.append("labels")
current_job.lables = labels
current_job.labels = labels
if bigquery_tables_log_ttl is not None:
update_mask.append("log_ttl")
current_job.log_ttl = bigquery_tables_log_ttl
current_job.log_ttl = duration_pb2.Duration(
seconds=bigquery_tables_log_ttl * 86400
)
if enable_monitoring_pipeline_logs is not None:
update_mask.append("enable_monitoring_pipeline_logs")
current_job.enable_monitoring_pipeline_logs = (
Expand All @@ -2495,6 +2501,9 @@ def update(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
self._gca_resource = self.api_client.get_model_deployment_monitoring_job(
name=mdm_job_name
)
return self

def pause(self) -> "ModelDeploymentMonitoringJob":
Expand Down
142 changes: 105 additions & 37 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import pytest
import copy

from unittest import mock
from importlib import reload
Expand Down Expand Up @@ -46,7 +47,9 @@
job_service_client,
)
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import duration_pb2 # type: ignore

import test_endpoints # noqa: F401
from test_endpoints import get_endpoint_with_models_mock # noqa: F401

_TEST_API_CLIENT = job_service_client.JobServiceClient
Expand Down Expand Up @@ -175,6 +178,56 @@
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"

_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}
_TEST_MDM_USER_EMAIL = "TEST_EMAIL"
_TEST_MDM_SAMPLE_RATE = 0.5
_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"}
_TEST_LOG_TTL_IN_DAYS = 1
_TEST_MDM_NEW_NAME = "NEW_NAME"

_TEST_MDM_OLD_JOB = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
endpoint=_TEST_ENDPOINT,
)
)

_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_MDM_NEW_NAME,
endpoint=_TEST_ENDPOINT,
model_deployment_monitoring_objective_configs=[
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id=model_id,
objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
drift_thresholds={
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(
value=0.01
)
}
)
),
)
for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS]
],
logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy(
random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(
sample_rate=_TEST_MDM_SAMPLE_RATE
)
),
labels=_TEST_MDM_LABEL,
model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig(
email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=[_TEST_MDM_USER_EMAIL]
)
),
model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig(
monitor_interval=duration_pb2.Duration(seconds=3600)
),
log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400),
enable_monitoring_pipeline_logs=True,
)

# TODO(b/171333554): Move reusable test fixtures to conftest.py file

Expand Down Expand Up @@ -988,48 +1041,22 @@ def get_mdm_job_mock():
with mock.patch.object(
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
) as get_mdm_job_mock:
get_mdm_job_mock.return_value = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
)
)
get_mdm_job_mock.side_effect = [
_TEST_MDM_OLD_JOB,
_TEST_MDM_OLD_JOB,
_TEST_MDM_OLD_JOB,
_TEST_MDM_OLD_JOB,
_TEST_MDM_EXPECTED_NEW_JOB,
]
yield get_mdm_job_mock


@pytest.fixture
@pytest.mark.usefixtures("get_mdm_job_mock")
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
with mock.patch.object(
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
) as update_mdm_job_mock:
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
drift_thresholds={
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
}
)
)
all_configs = []
for model in get_endpoint_with_models_mock.return_value.deployed_models:
all_configs.append(
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id=model.id,
objective_config=expected_objective_config,
)
)

update_mdm_job_mock.return_vaue.result_type = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
model_deployment_monitoring_objective_configs=all_configs,
)
)
update_mdm_job_mock.return_value.result_type = _TEST_MDM_EXPECTED_NEW_JOB
yield update_mdm_job_mock


Expand All @@ -1046,13 +1073,45 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
job = jobs.ModelDeploymentMonitoringJob(
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
)
old_job = copy.deepcopy(job._gca_resource)
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
)
schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1)
alert_config = aiplatform.model_monitoring.EmailAlertConfig(
user_emails=[_TEST_MDM_USER_EMAIL]
)
sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig(
sample_rate=_TEST_MDM_SAMPLE_RATE
)
labels = _TEST_MDM_LABEL
log_ttl = _TEST_LOG_TTL_IN_DAYS
display_name = _TEST_MDM_NEW_NAME
new_config = aiplatform.model_monitoring.ObjectiveConfig(
drift_detection_config=drift_detection_config
)
job.update(objective_configs=new_config)
job.update(
display_name=display_name,
schedule_config=schedule_config,
alert_config=alert_config,
logging_sampling_strategy=sampling_strategy,
labels=labels,
bigquery_tables_log_ttl=log_ttl,
enable_monitoring_pipeline_logs=True,
objective_configs=new_config,
)
new_job = job._gca_resource
assert old_job != new_job
assert new_job.display_name == display_name
assert new_job.logging_sampling_strategy == sampling_strategy.as_proto()
assert (
new_job.model_deployment_monitoring_schedule_config
== schedule_config.as_proto()
)
assert new_job.labels == labels
assert new_job.model_monitoring_alert_config == alert_config.as_proto()
assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS
assert new_job.enable_monitoring_pipeline_logs
assert (
job._gca_resource.model_deployment_monitoring_objective_configs[
0
Expand All @@ -1063,8 +1122,17 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
name=_TEST_MDM_JOB_NAME,
)
update_mdm_job_mock.assert_called_once_with(
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
model_deployment_monitoring_job=new_job,
update_mask=field_mask_pb2.FieldMask(
paths=["model_deployment_monitoring_objective_configs"]
paths=[
"display_name",
"model_deployment_monitoring_schedule_config",
"model_monitoring_alert_config",
"logging_sampling_strategy",
"labels",
"log_ttl",
"enable_monitoring_pipeline_logs",
"model_deployment_monitoring_objective_configs",
]
),
)