@@ -557,7 +557,7 @@ def teardown_method(self):
557557 initializer .global_pool .shutdown (wait = True )
558558
559559 @pytest .mark .parametrize ("sync" , [True , False ])
560- def test_run_call_pipeline_service_create_with_tabular_dataset (
560+ def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name (
561561 self ,
562562 mock_pipeline_service_create ,
563563 mock_pipeline_service_get ,
@@ -713,6 +713,155 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
713713
714714 assert job .state == gca_pipeline_state .PipelineState .PIPELINE_STATE_SUCCEEDED
715715
716+ def test_run_call_pipeline_service_create_with_tabular_dataset (
717+ self ,
718+ mock_pipeline_service_create ,
719+ mock_python_package_to_gcs ,
720+ mock_tabular_dataset ,
721+ mock_model_service_get ,
722+ ):
723+ aiplatform .init (
724+ project = _TEST_PROJECT ,
725+ staging_bucket = _TEST_BUCKET_NAME ,
726+ credentials = _TEST_CREDENTIALS ,
727+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
728+ )
729+
730+ job = training_jobs .CustomTrainingJob (
731+ display_name = _TEST_DISPLAY_NAME ,
732+ script_path = _TEST_LOCAL_SCRIPT_FILE_NAME ,
733+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
734+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
735+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
736+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
737+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
738+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
739+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
740+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
741+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
742+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
743+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
744+ model_description = _TEST_MODEL_DESCRIPTION ,
745+ )
746+
747+ model_from_job = job .run (
748+ dataset = mock_tabular_dataset ,
749+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
750+ args = _TEST_RUN_ARGS ,
751+ replica_count = 1 ,
752+ machine_type = _TEST_MACHINE_TYPE ,
753+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
754+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
755+ training_fraction_split = _TEST_TRAINING_FRACTION_SPLIT ,
756+ validation_fraction_split = _TEST_VALIDATION_FRACTION_SPLIT ,
757+ test_fraction_split = _TEST_TEST_FRACTION_SPLIT ,
758+ predefined_split_column_name = _TEST_PREDEFINED_SPLIT_COLUMN_NAME ,
759+ )
760+
761+ mock_python_package_to_gcs .assert_called_once_with (
762+ gcs_staging_dir = _TEST_BUCKET_NAME ,
763+ project = _TEST_PROJECT ,
764+ credentials = initializer .global_config .credentials ,
765+ )
766+
767+ true_args = _TEST_RUN_ARGS
768+
769+ true_worker_pool_spec = {
770+ "replicaCount" : _TEST_REPLICA_COUNT ,
771+ "machineSpec" : {
772+ "machineType" : _TEST_MACHINE_TYPE ,
773+ "acceleratorType" : _TEST_ACCELERATOR_TYPE ,
774+ "acceleratorCount" : _TEST_ACCELERATOR_COUNT ,
775+ },
776+ "pythonPackageSpec" : {
777+ "executorImageUri" : _TEST_TRAINING_CONTAINER_IMAGE ,
778+ "pythonModule" : training_jobs ._TrainingScriptPythonPackager .module_name ,
779+ "packageUris" : [_TEST_OUTPUT_PYTHON_PACKAGE_PATH ],
780+ "args" : true_args ,
781+ },
782+ }
783+
784+ true_fraction_split = gca_training_pipeline .FractionSplit (
785+ training_fraction = _TEST_TRAINING_FRACTION_SPLIT ,
786+ validation_fraction = _TEST_VALIDATION_FRACTION_SPLIT ,
787+ test_fraction = _TEST_TEST_FRACTION_SPLIT ,
788+ )
789+
790+ env = [
791+ gca_env_var .EnvVar (name = str (key ), value = str (value ))
792+ for key , value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES .items ()
793+ ]
794+
795+ ports = [
796+ gca_model .Port (container_port = port )
797+ for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
798+ ]
799+
800+ true_container_spec = gca_model .ModelContainerSpec (
801+ image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
802+ predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
803+ health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
804+ command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
805+ args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
806+ env = env ,
807+ ports = ports ,
808+ )
809+
810+ true_managed_model = gca_model .Model (
811+ display_name = _TEST_DISPLAY_NAME ,
812+ description = _TEST_MODEL_DESCRIPTION ,
813+ container_spec = true_container_spec ,
814+ predict_schemata = gca_model .PredictSchemata (
815+ instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
816+ parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
817+ prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
818+ ),
819+ encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
820+ )
821+
822+ true_input_data_config = gca_training_pipeline .InputDataConfig (
823+ fraction_split = true_fraction_split ,
824+ predefined_split = gca_training_pipeline .PredefinedSplit (
825+ key = _TEST_PREDEFINED_SPLIT_COLUMN_NAME
826+ ),
827+ dataset_id = mock_tabular_dataset .name ,
828+ gcs_destination = gca_io .GcsDestination (
829+ output_uri_prefix = _TEST_BASE_OUTPUT_DIR
830+ ),
831+ )
832+
833+ true_training_pipeline = gca_training_pipeline .TrainingPipeline (
834+ display_name = _TEST_DISPLAY_NAME ,
835+ training_task_definition = schema .training_job .definition .custom_task ,
836+ training_task_inputs = json_format .ParseDict (
837+ {
838+ "workerPoolSpecs" : [true_worker_pool_spec ],
839+ "baseOutputDirectory" : {"output_uri_prefix" : _TEST_BASE_OUTPUT_DIR },
840+ },
841+ struct_pb2 .Value (),
842+ ),
843+ model_to_upload = true_managed_model ,
844+ input_data_config = true_input_data_config ,
845+ encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
846+ )
847+
848+ mock_pipeline_service_create .assert_called_once_with (
849+ parent = initializer .global_config .common_location_path (),
850+ training_pipeline = true_training_pipeline ,
851+ )
852+
853+ assert job ._gca_resource is mock_pipeline_service_create .return_value
854+
855+ mock_model_service_get .assert_called_once_with (name = _TEST_MODEL_NAME )
856+
857+ assert model_from_job ._gca_resource is mock_model_service_get .return_value
858+
859+ assert job .get_model ()._gca_resource is mock_model_service_get .return_value
860+
861+ assert not job .has_failed
862+
863+ assert job .state == gca_pipeline_state .PipelineState .PIPELINE_STATE_SUCCEEDED
864+
716865 @pytest .mark .parametrize ("sync" , [True , False ])
717866 def test_run_call_pipeline_service_create_with_bigquery_destination (
718867 self ,
0 commit comments