Skip to content

Commit 6720044

Browse files
committed
fix: default model_display_name to _CustomTrainingJob.display_name when model_serving_container_image_uri is provided
1 parent c2caaa6 commit 6720044

File tree

2 files changed

+161
-1
lines changed

2 files changed

+161
-1
lines changed

google/cloud/aiplatform/training_jobs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,8 @@ def _prepare_and_validate_run(
14551455
If the script produces a managed AI Platform Model. The display name of
14561456
the Model. The name can be up to 128 characters long and can consist
14571457
of any UTF-8 characters.
1458+
1459+
If not provided upon creation and a model serving container image URI is set, the job's display_name with a "-model" suffix is used.
14581460
replica_count (int):
14591461
The number of worker replicas. If replica count = 1 then one chief
14601462
replica will be provisioned. If replica_count > 1 the remainder will be
@@ -1491,6 +1493,9 @@ def _prepare_and_validate_run(
14911493
"""
14921494
)
14931495

1496+
if self._managed_model.container_spec.image_uri:
1497+
model_display_name = model_display_name or self._display_name + "-model"
1498+
14941499
# validates args and will raise
14951500
worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool(
14961501
replica_count=replica_count,
@@ -1854,6 +1859,8 @@ def run(
18541859
If the script produces a managed AI Platform Model. The display name of
18551860
the Model. The name can be up to 128 characters long and can consist
18561861
of any UTF-8 characters.
1862+
1863+
If not provided upon creation, the job's display_name is used.
18571864
base_output_dir (str):
18581865
GCS output directory of job. If not provided a
18591866
timestamped directory in the staging directory will be used.
@@ -2371,6 +2378,8 @@ def run(
23712378
If the script produces a managed AI Platform Model. The display name of
23722379
the Model. The name can be up to 128 characters long and can consist
23732380
of any UTF-8 characters.
2381+
2382+
If not provided upon creation, the job's display_name is used.
23742383
base_output_dir (str):
23752384
GCS output directory of job. If not provided a
23762385
timestamped directory in the staging directory will be used.
@@ -3636,6 +3645,8 @@ def run(
36363645
If the script produces a managed AI Platform Model. The display name of
36373646
the Model. The name can be up to 128 characters long and can consist
36383647
of any UTF-8 characters.
3648+
3649+
If not provided upon creation, the job's display_name is used.
36393650
base_output_dir (str):
36403651
GCS output directory of job. If not provided a
36413652
timestamped directory in the staging directory will be used.

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def teardown_method(self):
557557
initializer.global_pool.shutdown(wait=True)
558558

559559
@pytest.mark.parametrize("sync", [True, False])
560-
def test_run_call_pipeline_service_create_with_tabular_dataset(
560+
def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name(
561561
self,
562562
mock_pipeline_service_create,
563563
mock_pipeline_service_get,
@@ -713,6 +713,155 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
713713

714714
assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
715715

716+
def test_run_call_pipeline_service_create_with_tabular_dataset(
    self,
    mock_pipeline_service_create,
    mock_python_package_to_gcs,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """Run a CustomTrainingJob on a tabular dataset without a model display name.

    The job is configured with a serving container image but ``run`` is
    called without ``model_display_name``. The test checks that the training
    pipeline sent to the pipeline service carries a model whose display name
    is derived from the job's display name, and that the resulting model is
    fetched and exposed by the job.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = training_jobs.CustomTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
    )

    # model_display_name is intentionally omitted so the default is exercised.
    model_from_job = job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        replica_count=1,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
        predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
    )

    mock_python_package_to_gcs.assert_called_once_with(
        gcs_staging_dir=_TEST_BUCKET_NAME,
        project=_TEST_PROJECT,
        credentials=initializer.global_config.credentials,
    )

    true_args = _TEST_RUN_ARGS

    true_worker_pool_spec = {
        "replicaCount": _TEST_REPLICA_COUNT,
        "machineSpec": {
            "machineType": _TEST_MACHINE_TYPE,
            "acceleratorType": _TEST_ACCELERATOR_TYPE,
            "acceleratorCount": _TEST_ACCELERATOR_COUNT,
        },
        "pythonPackageSpec": {
            "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE,
            "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name,
            "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
            "args": true_args,
        },
    }

    true_fraction_split = gca_training_pipeline.FractionSplit(
        training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction=_TEST_TEST_FRACTION_SPLIT,
    )

    env = [
        gca_env_var.EnvVar(name=str(key), value=str(value))
        for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
    ]

    ports = [
        gca_model.Port(container_port=port)
        for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
    ]

    true_container_spec = gca_model.ModelContainerSpec(
        image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        env=env,
        ports=ports,
    )

    true_managed_model = gca_model.Model(
        # No model_display_name was passed to run(); with a serving image set,
        # the job falls back to its own display name plus a "-model" suffix.
        display_name=_TEST_DISPLAY_NAME + "-model",
        description=_TEST_MODEL_DESCRIPTION,
        container_spec=true_container_spec,
        predict_schemata=gca_model.PredictSchemata(
            instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
            parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
            prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        ),
        encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
    )

    true_input_data_config = gca_training_pipeline.InputDataConfig(
        fraction_split=true_fraction_split,
        predefined_split=gca_training_pipeline.PredefinedSplit(
            key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
        ),
        dataset_id=mock_tabular_dataset.name,
        gcs_destination=gca_io.GcsDestination(
            output_uri_prefix=_TEST_BASE_OUTPUT_DIR
        ),
    )

    true_training_pipeline = gca_training_pipeline.TrainingPipeline(
        display_name=_TEST_DISPLAY_NAME,
        training_task_definition=schema.training_job.definition.custom_task,
        training_task_inputs=json_format.ParseDict(
            {
                "workerPoolSpecs": [true_worker_pool_spec],
                "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR},
            },
            struct_pb2.Value(),
        ),
        model_to_upload=true_managed_model,
        input_data_config=true_input_data_config,
        encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=initializer.global_config.common_location_path(),
        training_pipeline=true_training_pipeline,
    )

    assert job._gca_resource is mock_pipeline_service_create.return_value

    mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME)

    # The model returned by run() and by get_model() must both wrap the
    # resource served by the mocked model service.
    assert model_from_job._gca_resource is mock_model_service_get.return_value

    assert job.get_model()._gca_resource is mock_model_service_get.return_value

    assert not job.has_failed

    assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
864+
716865
@pytest.mark.parametrize("sync", [True, False])
717866
def test_run_call_pipeline_service_create_with_bigquery_destination(
718867
self,

0 commit comments

Comments
 (0)