Skip to content

Commit 8f50d38

Browse files
committed
fix test
1 parent 41e9fe3 commit 8f50d38

File tree

2 files changed

+206
-59
lines changed

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Dict, Mapping, Optional, Union
2020
import packaging.version
2121

22+
2223
class PipelineRuntimeConfigBuilder(object):
2324
"""Pipeline RuntimeConfig builder.
2425
@@ -47,7 +48,7 @@ def __init__(
4748
self._pipeline_root = pipeline_root
4849
self._parameter_types = parameter_types
4950
self._parameter_values = copy.deepcopy(parameter_values or {})
50-
self._schema_version = schema_version or '2.0.0'
51+
self._schema_version = schema_version or "2.0.0"
5152

5253
@classmethod
5354
def from_job_spec_json(
@@ -63,9 +64,15 @@ def from_job_spec_json(
6364
A PipelineRuntimeConfigBuilder object.
6465
"""
6566
runtime_config_spec = job_spec["runtimeConfig"]
66-
input_definitions = job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
67-
parameter_input_definitions = input_definitions.get("parameter_values") or input_definitions.get("parameters") or {}
68-
schema_version = job_spec.get('schemaVersion')
67+
input_definitions = (
68+
job_spec["pipelineSpec"]["root"].get("inputDefinitions") or {}
69+
)
70+
parameter_input_definitions = (
71+
input_definitions.get("parameterValues")
72+
or input_definitions.get("parameters")
73+
or {}
74+
)
75+
schema_version = job_spec["pipelineSpec"].get("schemaVersion")
6976

7077
# 'type' is deprecated in IR and change to 'parameterType'.
7178
parameter_types = {
@@ -98,9 +105,12 @@ def update_runtime_parameters(
98105
"""
99106
if parameter_values:
100107
parameters = dict(parameter_values)
101-
for k, v in parameter_values.items():
102-
if isinstance(v, (dict, list, bool)):
103-
parameters[k] = json.dumps(v)
108+
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
109+
"2.0.0"
110+
):
111+
for k, v in parameter_values.items():
112+
if isinstance(v, (dict, list, bool)):
113+
parameters[k] = json.dumps(v)
104114
self._parameter_values.update(parameters)
105115

106116
def build(self) -> Dict[str, Any]:
@@ -114,10 +124,12 @@ def build(self) -> Dict[str, Any]:
114124
"Pipeline root must be specified, either during "
115125
"compile time, or when calling the service."
116126
)
117-
if packaging.version.parse(self._schema_version) >= packaging.version.parse("2.1.0"):
118-
parameter_values_key = 'parameter_values'
127+
if packaging.version.parse(self._schema_version) > packaging.version.parse(
128+
"2.0.0"
129+
):
130+
parameter_values_key = "parameterValues"
119131
else:
120-
parameter_values_key = 'parameters'
132+
parameter_values_key = "parameters"
121133
return {
122134
"gcsOutputDirectory": self._pipeline_root,
123135
parameter_values_key: {
@@ -173,7 +185,6 @@ def _get_vertex_value(
173185
result["structValue"] = value
174186
else:
175187
raise TypeError("Got unknown type of value: {}".format(value))
176-
177188
return result
178189

179190

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 184 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -53,53 +53,50 @@
5353

5454
_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
5555

56-
_TEST_PIPELINE_PARAMETER_VALUES = {"string_param": "hello"}
56+
_TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"}
57+
_TEST_PIPELINE_PARAMETER_VALUES = {
58+
"string_param": "hello world",
59+
"bool_param": True,
60+
"double_param": 12.34,
61+
"int_param": 5678,
62+
"list_int_param": [123, 456, 789],
63+
"list_string_param": ["lorem", "ipsum"],
64+
"struct_param": {"key1": 12345, "key2": 67890},
65+
}
66+
5767
_TEST_PIPELINE_SPEC_LEGACY = {
5868
"pipelineInfo": {"name": "my-pipeline"},
5969
"root": {
6070
"dag": {"tasks": {}},
6171
"inputDefinitions": {"parameters": {"string_param": {"type": "STRING"}}},
6272
},
63-
"schema_version": "2.0.0",
73+
"schemaVersion": "2.0.0",
6474
"components": {},
6575
}
6676
_TEST_PIPELINE_SPEC = {
6777
"pipelineInfo": {"name": "my-pipeline"},
6878
"root": {
6979
"dag": {"tasks": {}},
7080
"inputDefinitions": {
71-
"parameter_values": {
81+
"parameterValues": {
7282
"string_param": {"parameterType": "STRING"},
73-
"bool_param": {
74-
"parameterType": "BOOLEAN"
75-
},
76-
"double_param": {
77-
"parameterType": "NUMBER_DOUBLE"
78-
},
79-
"int_param": {
80-
"parameterType": "NUMBER_INTEGER"
81-
},
82-
"list_int_param": {
83-
"parameterType": "LIST"
84-
},
85-
"list_string_param": {
86-
"parameterType": "LIST"
87-
},
88-
"struct_param": {
89-
"parameterType": "STRUCT"
90-
}
83+
"bool_param": {"parameterType": "BOOLEAN"},
84+
"double_param": {"parameterType": "NUMBER_DOUBLE"},
85+
"int_param": {"parameterType": "NUMBER_INTEGER"},
86+
"list_int_param": {"parameterType": "LIST"},
87+
"list_string_param": {"parameterType": "LIST"},
88+
"struct_param": {"parameterType": "STRUCT"},
9189
}
9290
},
9391
},
94-
"schema_version":"2.1.0",
92+
"schemaVersion": "2.1.0",
9593
"components": {},
9694
}
9795

9896
_TEST_PIPELINE_JOB_LEGACY = {
9997
"runtimeConfig": {},
10098
"pipelineSpec": _TEST_PIPELINE_SPEC_LEGACY,
10199
}
102-
103100
_TEST_PIPELINE_JOB = {
104101
"runtimeConfig": {
105102
"parameterValues": {
@@ -109,7 +106,7 @@
109106
"int_param": 5678,
110107
"list_int_param": [123, 456, 789],
111108
"list_string_param": ["lorem", "ipsum"],
112-
"struct_param": { "key1": 12345, "key2": 67890}
109+
"struct_param": {"key1": 12345, "key2": 67890},
113110
},
114111
},
115112
"pipelineSpec": _TEST_PIPELINE_SPEC,
@@ -250,13 +247,7 @@ def teardown_method(self):
250247
initializer.global_pool.shutdown(wait=True)
251248

252249
@pytest.mark.parametrize(
253-
"job_spec_json",
254-
[
255-
_TEST_PIPELINE_SPEC,
256-
_TEST_PIPELINE_JOB,
257-
_TEST_PIPELINE_SPEC_LEGACY,
258-
_TEST_PIPELINE_JOB_LEGACY,
259-
],
250+
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
260251
)
261252
@pytest.mark.parametrize("sync", [True, False])
262253
def test_run_call_pipeline_service_create(
@@ -291,7 +282,15 @@ def test_run_call_pipeline_service_create(
291282

292283
expected_runtime_config_dict = {
293284
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
294-
"parameter_values": {"string_param": {"stringValue": "hello"}},
285+
"parameterValues": {
286+
"bool_param": {"boolValue": True},
287+
"double_param": {"numberValue": 12.34},
288+
"int_param": {"numberValue": 5678},
289+
"list_int_param": {"listValue": [123, 456, 789]},
290+
"list_string_param": {"listValue": ["lorem", "ipsum"]},
291+
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
292+
"string_param": {"stringValue": "hello world"},
293+
},
295294
}
296295
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
297296
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
@@ -305,6 +304,7 @@ def test_run_call_pipeline_service_create(
305304
"components": {},
306305
"pipelineInfo": pipeline_spec["pipelineInfo"],
307306
"root": pipeline_spec["root"],
307+
"schemaVersion": "2.1.0",
308308
},
309309
runtime_config=runtime_config,
310310
service_account=_TEST_SERVICE_ACCOUNT,
@@ -326,13 +326,78 @@ def test_run_call_pipeline_service_create(
326326
)
327327

328328
@pytest.mark.parametrize(
329-
"job_spec_json",
330-
[
331-
_TEST_PIPELINE_SPEC,
332-
_TEST_PIPELINE_JOB,
333-
_TEST_PIPELINE_SPEC_LEGACY,
334-
_TEST_PIPELINE_JOB_LEGACY,
335-
],
329+
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
330+
)
331+
@pytest.mark.parametrize("sync", [True, False])
332+
def test_run_call_pipeline_service_create_legacy(
333+
self,
334+
mock_pipeline_service_create,
335+
mock_pipeline_service_get,
336+
job_spec_json,
337+
mock_load_json,
338+
sync,
339+
):
340+
aiplatform.init(
341+
project=_TEST_PROJECT,
342+
staging_bucket=_TEST_GCS_BUCKET_NAME,
343+
location=_TEST_LOCATION,
344+
credentials=_TEST_CREDENTIALS,
345+
)
346+
347+
job = pipeline_jobs.PipelineJob(
348+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
349+
template_path=_TEST_TEMPLATE_PATH,
350+
job_id=_TEST_PIPELINE_JOB_ID,
351+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
352+
enable_caching=True,
353+
)
354+
355+
job.run(
356+
service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, sync=sync,
357+
)
358+
359+
if not sync:
360+
job.wait()
361+
362+
expected_runtime_config_dict = {
363+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
364+
"parameters": {"string_param": {"stringValue": "hello"}},
365+
}
366+
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
367+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
368+
369+
pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json
370+
371+
# Construct expected request
372+
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
373+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
374+
pipeline_spec={
375+
"components": {},
376+
"pipelineInfo": pipeline_spec["pipelineInfo"],
377+
"root": pipeline_spec["root"],
378+
"schemaVersion": "2.0.0",
379+
},
380+
runtime_config=runtime_config,
381+
service_account=_TEST_SERVICE_ACCOUNT,
382+
network=_TEST_NETWORK,
383+
)
384+
385+
mock_pipeline_service_create.assert_called_once_with(
386+
parent=_TEST_PARENT,
387+
pipeline_job=expected_gapic_pipeline_job,
388+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
389+
)
390+
391+
mock_pipeline_service_get.assert_called_with(
392+
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
393+
)
394+
395+
assert job._gca_resource == make_pipeline_job(
396+
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
397+
)
398+
399+
@pytest.mark.parametrize(
400+
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
336401
)
337402
def test_submit_call_pipeline_service_pipeline_job_create(
338403
self,
@@ -359,8 +424,84 @@ def test_submit_call_pipeline_service_pipeline_job_create(
359424
job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)
360425

361426
expected_runtime_config_dict = {
362-
"gcs_output_directory": _TEST_GCS_BUCKET_NAME,
363-
"parameter_values": {"string_param": {"stringValue": "hello"}},
427+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
428+
"parameterValues": {
429+
"bool_param": {"boolValue": True},
430+
"double_param": {"numberValue": 12.34},
431+
"int_param": {"numberValue": 5678},
432+
"list_int_param": {"listValue": [123, 456, 789]},
433+
"list_string_param": {"listValue": ["lorem", "ipsum"]},
434+
"struct_param": {"structValue": {"key1": 12345, "key2": 67890}},
435+
"string_param": {"stringValue": "hello world"},
436+
},
437+
}
438+
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
439+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
440+
441+
pipeline_spec = job_spec_json.get("pipelineSpec") or job_spec_json
442+
443+
# Construct expected request
444+
expected_gapic_pipeline_job = gca_pipeline_job_v1.PipelineJob(
445+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
446+
pipeline_spec={
447+
"components": {},
448+
"pipelineInfo": pipeline_spec["pipelineInfo"],
449+
"root": pipeline_spec["root"],
450+
"schemaVersion": "2.1.0",
451+
},
452+
runtime_config=runtime_config,
453+
service_account=_TEST_SERVICE_ACCOUNT,
454+
network=_TEST_NETWORK,
455+
)
456+
457+
mock_pipeline_service_create.assert_called_once_with(
458+
parent=_TEST_PARENT,
459+
pipeline_job=expected_gapic_pipeline_job,
460+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
461+
)
462+
463+
assert not mock_pipeline_service_get.called
464+
465+
job.wait()
466+
467+
mock_pipeline_service_get.assert_called_with(
468+
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
469+
)
470+
471+
assert job._gca_resource == make_pipeline_job(
472+
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
473+
)
474+
475+
@pytest.mark.parametrize(
476+
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
477+
)
478+
def test_submit_call_pipeline_service_pipeline_job_create_legacy(
479+
self,
480+
mock_pipeline_service_create,
481+
mock_pipeline_service_get,
482+
job_spec_json,
483+
mock_load_json,
484+
):
485+
aiplatform.init(
486+
project=_TEST_PROJECT,
487+
staging_bucket=_TEST_GCS_BUCKET_NAME,
488+
location=_TEST_LOCATION,
489+
credentials=_TEST_CREDENTIALS,
490+
)
491+
492+
job = pipeline_jobs.PipelineJob(
493+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
494+
template_path=_TEST_TEMPLATE_PATH,
495+
job_id=_TEST_PIPELINE_JOB_ID,
496+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES_LEGACY,
497+
enable_caching=True,
498+
)
499+
500+
job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)
501+
502+
expected_runtime_config_dict = {
503+
"parameters": {"string_param": {"stringValue": "hello"}},
504+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
364505
}
365506
runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
366507
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
@@ -374,6 +515,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
374515
"components": {},
375516
"pipelineInfo": pipeline_spec["pipelineInfo"],
376517
"root": pipeline_spec["root"],
518+
"schemaVersion": "2.0.0",
377519
},
378520
runtime_config=runtime_config,
379521
service_account=_TEST_SERVICE_ACCOUNT,
@@ -508,13 +650,7 @@ def test_cancel_pipeline_job_without_running(
508650
"mock_pipeline_service_create", "mock_pipeline_service_get_with_fail",
509651
)
510652
@pytest.mark.parametrize(
511-
"job_spec_json",
512-
[
513-
_TEST_PIPELINE_SPEC,
514-
_TEST_PIPELINE_JOB,
515-
_TEST_PIPELINE_SPEC_LEGACY,
516-
_TEST_PIPELINE_JOB_LEGACY,
517-
],
653+
"job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
518654
)
519655
@pytest.mark.parametrize("sync", [True, False])
520656
def test_pipeline_failure_raises(self, mock_load_json, sync):

0 commit comments

Comments
 (0)