5353
5454_TEST_PIPELINE_JOB_NAME = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /pipelineJobs/{ _TEST_PIPELINE_JOB_ID } "
5555
56- _TEST_PIPELINE_PARAMETER_VALUES = {"string_param" : "hello" }
56+ _TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param" : "hello" }
57+ _TEST_PIPELINE_PARAMETER_VALUES = {
58+ "string_param" : "hello world" ,
59+ "bool_param" : True ,
60+ "double_param" : 12.34 ,
61+ "int_param" : 5678 ,
62+ "list_int_param" : [123 , 456 , 789 ],
63+ "list_string_param" : ["lorem" , "ipsum" ],
64+ "struct_param" : {"key1" : 12345 , "key2" : 67890 },
65+ }
66+
5767_TEST_PIPELINE_SPEC_LEGACY = {
5868 "pipelineInfo" : {"name" : "my-pipeline" },
5969 "root" : {
6070 "dag" : {"tasks" : {}},
6171 "inputDefinitions" : {"parameters" : {"string_param" : {"type" : "STRING" }}},
6272 },
63- "schema_version " : "2.0.0" ,
73+ "schemaVersion " : "2.0.0" ,
6474 "components" : {},
6575}
6676_TEST_PIPELINE_SPEC = {
6777 "pipelineInfo" : {"name" : "my-pipeline" },
6878 "root" : {
6979 "dag" : {"tasks" : {}},
7080 "inputDefinitions" : {
71- "parameter_values " : {
81+ "parameterValues " : {
7282 "string_param" : {"parameterType" : "STRING" },
73- "bool_param" : {
74- "parameterType" : "BOOLEAN"
75- },
76- "double_param" : {
77- "parameterType" : "NUMBER_DOUBLE"
78- },
79- "int_param" : {
80- "parameterType" : "NUMBER_INTEGER"
81- },
82- "list_int_param" : {
83- "parameterType" : "LIST"
84- },
85- "list_string_param" : {
86- "parameterType" : "LIST"
87- },
88- "struct_param" : {
89- "parameterType" : "STRUCT"
90- }
83+ "bool_param" : {"parameterType" : "BOOLEAN" },
84+ "double_param" : {"parameterType" : "NUMBER_DOUBLE" },
85+ "int_param" : {"parameterType" : "NUMBER_INTEGER" },
86+ "list_int_param" : {"parameterType" : "LIST" },
87+ "list_string_param" : {"parameterType" : "LIST" },
88+ "struct_param" : {"parameterType" : "STRUCT" },
9189 }
9290 },
9391 },
94- "schema_version" : "2.1.0" ,
92+ "schemaVersion" : "2.1.0" ,
9593 "components" : {},
9694}
9795
9896_TEST_PIPELINE_JOB_LEGACY = {
9997 "runtimeConfig" : {},
10098 "pipelineSpec" : _TEST_PIPELINE_SPEC_LEGACY ,
10199}
102-
103100_TEST_PIPELINE_JOB = {
104101 "runtimeConfig" : {
105102 "parameterValues" : {
109106 "int_param" : 5678 ,
110107 "list_int_param" : [123 , 456 , 789 ],
111108 "list_string_param" : ["lorem" , "ipsum" ],
112- "struct_param" : { "key1" : 12345 , "key2" : 67890 }
109+ "struct_param" : {"key1" : 12345 , "key2" : 67890 },
113110 },
114111 },
115112 "pipelineSpec" : _TEST_PIPELINE_SPEC ,
@@ -250,13 +247,7 @@ def teardown_method(self):
250247 initializer .global_pool .shutdown (wait = True )
251248
252249 @pytest .mark .parametrize (
253- "job_spec_json" ,
254- [
255- _TEST_PIPELINE_SPEC ,
256- _TEST_PIPELINE_JOB ,
257- _TEST_PIPELINE_SPEC_LEGACY ,
258- _TEST_PIPELINE_JOB_LEGACY ,
259- ],
250+ "job_spec_json" , [_TEST_PIPELINE_SPEC , _TEST_PIPELINE_JOB ],
260251 )
261252 @pytest .mark .parametrize ("sync" , [True , False ])
262253 def test_run_call_pipeline_service_create (
@@ -291,7 +282,15 @@ def test_run_call_pipeline_service_create(
291282
292283 expected_runtime_config_dict = {
293284 "gcsOutputDirectory" : _TEST_GCS_BUCKET_NAME ,
294- "parameter_values" : {"string_param" : {"stringValue" : "hello" }},
285+ "parameterValues" : {
286+ "bool_param" : {"boolValue" : True },
287+ "double_param" : {"numberValue" : 12.34 },
288+ "int_param" : {"numberValue" : 5678 },
289+ "list_int_param" : {"listValue" : [123 , 456 , 789 ]},
290+ "list_string_param" : {"listValue" : ["lorem" , "ipsum" ]},
291+ "struct_param" : {"structValue" : {"key1" : 12345 , "key2" : 67890 }},
292+ "string_param" : {"stringValue" : "hello world" },
293+ },
295294 }
296295 runtime_config = gca_pipeline_job_v1 .PipelineJob .RuntimeConfig ()._pb
297296 json_format .ParseDict (expected_runtime_config_dict , runtime_config )
@@ -305,6 +304,7 @@ def test_run_call_pipeline_service_create(
305304 "components" : {},
306305 "pipelineInfo" : pipeline_spec ["pipelineInfo" ],
307306 "root" : pipeline_spec ["root" ],
307+ "schemaVersion" : "2.1.0" ,
308308 },
309309 runtime_config = runtime_config ,
310310 service_account = _TEST_SERVICE_ACCOUNT ,
@@ -326,13 +326,78 @@ def test_run_call_pipeline_service_create(
326326 )
327327
328328 @pytest .mark .parametrize (
329- "job_spec_json" ,
330- [
331- _TEST_PIPELINE_SPEC ,
332- _TEST_PIPELINE_JOB ,
333- _TEST_PIPELINE_SPEC_LEGACY ,
334- _TEST_PIPELINE_JOB_LEGACY ,
335- ],
329+ "job_spec_json" , [_TEST_PIPELINE_SPEC_LEGACY , _TEST_PIPELINE_JOB_LEGACY ],
330+ )
331+ @pytest .mark .parametrize ("sync" , [True , False ])
332+ def test_run_call_pipeline_service_create_legacy (
333+ self ,
334+ mock_pipeline_service_create ,
335+ mock_pipeline_service_get ,
336+ job_spec_json ,
337+ mock_load_json ,
338+ sync ,
339+ ):
340+ aiplatform .init (
341+ project = _TEST_PROJECT ,
342+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
343+ location = _TEST_LOCATION ,
344+ credentials = _TEST_CREDENTIALS ,
345+ )
346+
347+ job = pipeline_jobs .PipelineJob (
348+ display_name = _TEST_PIPELINE_JOB_DISPLAY_NAME ,
349+ template_path = _TEST_TEMPLATE_PATH ,
350+ job_id = _TEST_PIPELINE_JOB_ID ,
351+ parameter_values = _TEST_PIPELINE_PARAMETER_VALUES_LEGACY ,
352+ enable_caching = True ,
353+ )
354+
355+ job .run (
356+ service_account = _TEST_SERVICE_ACCOUNT , network = _TEST_NETWORK , sync = sync ,
357+ )
358+
359+ if not sync :
360+ job .wait ()
361+
362+ expected_runtime_config_dict = {
363+ "gcsOutputDirectory" : _TEST_GCS_BUCKET_NAME ,
364+ "parameters" : {"string_param" : {"stringValue" : "hello" }},
365+ }
366+ runtime_config = gca_pipeline_job_v1 .PipelineJob .RuntimeConfig ()._pb
367+ json_format .ParseDict (expected_runtime_config_dict , runtime_config )
368+
369+ pipeline_spec = job_spec_json .get ("pipelineSpec" ) or job_spec_json
370+
371+ # Construct expected request
372+ expected_gapic_pipeline_job = gca_pipeline_job_v1 .PipelineJob (
373+ display_name = _TEST_PIPELINE_JOB_DISPLAY_NAME ,
374+ pipeline_spec = {
375+ "components" : {},
376+ "pipelineInfo" : pipeline_spec ["pipelineInfo" ],
377+ "root" : pipeline_spec ["root" ],
378+ "schemaVersion" : "2.0.0" ,
379+ },
380+ runtime_config = runtime_config ,
381+ service_account = _TEST_SERVICE_ACCOUNT ,
382+ network = _TEST_NETWORK ,
383+ )
384+
385+ mock_pipeline_service_create .assert_called_once_with (
386+ parent = _TEST_PARENT ,
387+ pipeline_job = expected_gapic_pipeline_job ,
388+ pipeline_job_id = _TEST_PIPELINE_JOB_ID ,
389+ )
390+
391+ mock_pipeline_service_get .assert_called_with (
392+ name = _TEST_PIPELINE_JOB_NAME , retry = base ._DEFAULT_RETRY
393+ )
394+
395+ assert job ._gca_resource == make_pipeline_job (
396+ gca_pipeline_state_v1 .PipelineState .PIPELINE_STATE_SUCCEEDED
397+ )
398+
399+ @pytest .mark .parametrize (
400+ "job_spec_json" , [_TEST_PIPELINE_SPEC , _TEST_PIPELINE_JOB ],
336401 )
337402 def test_submit_call_pipeline_service_pipeline_job_create (
338403 self ,
@@ -359,8 +424,84 @@ def test_submit_call_pipeline_service_pipeline_job_create(
359424 job .submit (service_account = _TEST_SERVICE_ACCOUNT , network = _TEST_NETWORK )
360425
361426 expected_runtime_config_dict = {
362- "gcs_output_directory" : _TEST_GCS_BUCKET_NAME ,
363- "parameter_values" : {"string_param" : {"stringValue" : "hello" }},
427+ "gcsOutputDirectory" : _TEST_GCS_BUCKET_NAME ,
428+ "parameterValues" : {
429+ "bool_param" : {"boolValue" : True },
430+ "double_param" : {"numberValue" : 12.34 },
431+ "int_param" : {"numberValue" : 5678 },
432+ "list_int_param" : {"listValue" : [123 , 456 , 789 ]},
433+ "list_string_param" : {"listValue" : ["lorem" , "ipsum" ]},
434+ "struct_param" : {"structValue" : {"key1" : 12345 , "key2" : 67890 }},
435+ "string_param" : {"stringValue" : "hello world" },
436+ },
437+ }
438+ runtime_config = gca_pipeline_job_v1 .PipelineJob .RuntimeConfig ()._pb
439+ json_format .ParseDict (expected_runtime_config_dict , runtime_config )
440+
441+ pipeline_spec = job_spec_json .get ("pipelineSpec" ) or job_spec_json
442+
443+ # Construct expected request
444+ expected_gapic_pipeline_job = gca_pipeline_job_v1 .PipelineJob (
445+ display_name = _TEST_PIPELINE_JOB_DISPLAY_NAME ,
446+ pipeline_spec = {
447+ "components" : {},
448+ "pipelineInfo" : pipeline_spec ["pipelineInfo" ],
449+ "root" : pipeline_spec ["root" ],
450+ "schemaVersion" : "2.1.0" ,
451+ },
452+ runtime_config = runtime_config ,
453+ service_account = _TEST_SERVICE_ACCOUNT ,
454+ network = _TEST_NETWORK ,
455+ )
456+
457+ mock_pipeline_service_create .assert_called_once_with (
458+ parent = _TEST_PARENT ,
459+ pipeline_job = expected_gapic_pipeline_job ,
460+ pipeline_job_id = _TEST_PIPELINE_JOB_ID ,
461+ )
462+
463+ assert not mock_pipeline_service_get .called
464+
465+ job .wait ()
466+
467+ mock_pipeline_service_get .assert_called_with (
468+ name = _TEST_PIPELINE_JOB_NAME , retry = base ._DEFAULT_RETRY
469+ )
470+
471+ assert job ._gca_resource == make_pipeline_job (
472+ gca_pipeline_state_v1 .PipelineState .PIPELINE_STATE_SUCCEEDED
473+ )
474+
475+ @pytest .mark .parametrize (
476+ "job_spec_json" , [_TEST_PIPELINE_SPEC_LEGACY , _TEST_PIPELINE_JOB_LEGACY ],
477+ )
478+ def test_submit_call_pipeline_service_pipeline_job_create_legacy (
479+ self ,
480+ mock_pipeline_service_create ,
481+ mock_pipeline_service_get ,
482+ job_spec_json ,
483+ mock_load_json ,
484+ ):
485+ aiplatform .init (
486+ project = _TEST_PROJECT ,
487+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
488+ location = _TEST_LOCATION ,
489+ credentials = _TEST_CREDENTIALS ,
490+ )
491+
492+ job = pipeline_jobs .PipelineJob (
493+ display_name = _TEST_PIPELINE_JOB_DISPLAY_NAME ,
494+ template_path = _TEST_TEMPLATE_PATH ,
495+ job_id = _TEST_PIPELINE_JOB_ID ,
496+ parameter_values = _TEST_PIPELINE_PARAMETER_VALUES_LEGACY ,
497+ enable_caching = True ,
498+ )
499+
500+ job .submit (service_account = _TEST_SERVICE_ACCOUNT , network = _TEST_NETWORK )
501+
502+ expected_runtime_config_dict = {
503+ "parameters" : {"string_param" : {"stringValue" : "hello" }},
504+ "gcsOutputDirectory" : _TEST_GCS_BUCKET_NAME ,
364505 }
365506 runtime_config = gca_pipeline_job_v1 .PipelineJob .RuntimeConfig ()._pb
366507 json_format .ParseDict (expected_runtime_config_dict , runtime_config )
@@ -374,6 +515,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
374515 "components" : {},
375516 "pipelineInfo" : pipeline_spec ["pipelineInfo" ],
376517 "root" : pipeline_spec ["root" ],
518+ "schemaVersion" : "2.0.0" ,
377519 },
378520 runtime_config = runtime_config ,
379521 service_account = _TEST_SERVICE_ACCOUNT ,
@@ -508,13 +650,7 @@ def test_cancel_pipeline_job_without_running(
508650 "mock_pipeline_service_create" , "mock_pipeline_service_get_with_fail" ,
509651 )
510652 @pytest .mark .parametrize (
511- "job_spec_json" ,
512- [
513- _TEST_PIPELINE_SPEC ,
514- _TEST_PIPELINE_JOB ,
515- _TEST_PIPELINE_SPEC_LEGACY ,
516- _TEST_PIPELINE_JOB_LEGACY ,
517- ],
653+ "job_spec_json" , [_TEST_PIPELINE_SPEC , _TEST_PIPELINE_JOB ],
518654 )
519655 @pytest .mark .parametrize ("sync" , [True , False ])
520656 def test_pipeline_failure_raises (self , mock_load_json , sync ):
0 commit comments