66#
77# http://www.apache.org/licenses/LICENSE-2.0
88#
9- # or in the "license" file accompanying this file. This file is distributed
10- # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11- # express or implied. See the License for the specific language governing
9+ # or in the "license" file accompanying this file. This file is distributed
10+ # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+ # express or implied. See the License for the specific language governing
1212# permissions and limitations under the License.
1313from __future__ import absolute_import
1414
1515import pytest
1616import json
1717
1818from sagemaker .utils import unique_name_from_base
19- from sagemaker .image_uris import retrieve
19+ from sagemaker .image_uris import retrieve
2020from stepfunctions import steps
2121from stepfunctions .workflow import Workflow
2222from stepfunctions .steps .utils import get_aws_partition
2525
2626@pytest .fixture (scope = "module" )
2727def training_job_parameters (sagemaker_session , sagemaker_role_arn , record_set_fixture ):
28- parameters = {
29- "AlgorithmSpecification" : {
28+ parameters = {
29+ "AlgorithmSpecification" : {
3030 "TrainingImage" : retrieve (region = sagemaker_session .boto_session .region_name , framework = 'pca' ),
3131 "TrainingInputMode" : "File"
3232 },
33- "OutputDataConfig" : {
33+ "OutputDataConfig" : {
3434 "S3OutputPath" : "s3://{}/" .format (sagemaker_session .default_bucket ())
3535 },
36- "StoppingCondition" : {
36+ "StoppingCondition" : {
3737 "MaxRuntimeInSeconds" : 86400
3838 },
39- "ResourceConfig" : {
39+ "ResourceConfig" : {
4040 "InstanceCount" : 1 ,
4141 "InstanceType" : "ml.m5.large" ,
4242 "VolumeSizeInGB" : 30
4343 },
4444 "RoleArn" : sagemaker_role_arn ,
45- "InputDataConfig" :[
46- {
47- "DataSource" : {
48- "S3DataSource" : {
45+ "InputDataConfig" :[
46+ {
47+ "DataSource" : {
48+ "S3DataSource" : {
4949 "S3DataDistributionType" : "ShardedByS3Key" ,
5050 "S3DataType" : "ManifestFile" ,
5151 "S3Uri" : record_set_fixture .s3_data
@@ -54,7 +54,7 @@ def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fi
5454 "ChannelName" : "train"
5555 }
5656 ],
57- "HyperParameters" : {
57+ "HyperParameters" : {
5858 "num_components" : "48" ,
5959 "feature_dim" : "784" ,
6060 "mini_batch_size" : "200"
@@ -93,7 +93,7 @@ def test_pass_state_machine_creation(sfn_client, sfn_role_arn):
9393
9494 definition = steps .Pass (pass_state_name , result = pass_state_result )
9595 workflow = Workflow (
96- 'Test_Pass_Workflow' ,
96+ unique_name_from_base ( 'Test_Pass_Workflow' ) ,
9797 definition = definition ,
9898 role = sfn_role_arn
9999 )
@@ -164,7 +164,7 @@ def test_wait_state_machine_creation(sfn_client, sfn_role_arn):
164164 ])
165165
166166 workflow = Workflow (
167- 'Test_Wait_Workflow' ,
167+ unique_name_from_base ( 'Test_Wait_Workflow' ) ,
168168 definition = definition ,
169169 role = sfn_role_arn
170170 )
@@ -223,7 +223,7 @@ def test_parallel_state_machine_creation(sfn_client, sfn_role_arn):
223223 ])
224224
225225 workflow = Workflow (
226- 'Test_Parallel_Workflow' ,
226+ unique_name_from_base ( 'Test_Parallel_Workflow' ) ,
227227 definition = definition ,
228228 role = sfn_role_arn
229229 )
@@ -269,9 +269,9 @@ def test_map_state_machine_creation(sfn_client, sfn_role_arn):
269269 }
270270
271271 map_state = steps .Map (
272- map_state_name ,
272+ map_state_name ,
273273 items_path = items_path ,
274- iterator = steps .Pass (iterated_state_name ),
274+ iterator = steps .Pass (iterated_state_name ),
275275 max_concurrency = max_concurrency )
276276
277277 definition = steps .Chain ([
@@ -280,7 +280,7 @@ def test_map_state_machine_creation(sfn_client, sfn_role_arn):
280280 ])
281281
282282 workflow = Workflow (
283- 'Test_Map_Workflow' ,
283+ unique_name_from_base ( 'Test_Map_Workflow' ) ,
284284 definition = definition ,
285285 role = sfn_role_arn
286286 )
@@ -345,8 +345,8 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn):
345345
346346 definition .default_choice (
347347 steps .Fail (
348- default_state_name ,
349- error = default_error ,
348+ default_state_name ,
349+ error = default_error ,
350350 cause = default_cause
351351 )
352352 )
@@ -356,23 +356,23 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn):
356356 value = first_choice_value
357357 ),
358358 steps .Pass (
359- first_match_name ,
359+ first_match_name ,
360360 result = first_choice_state_result
361361 )
362362 )
363363 definition .add_choice (
364364 steps .ChoiceRule .NumericEquals (
365- variable = variable ,
365+ variable = variable ,
366366 value = second_choice_value
367- ),
367+ ),
368368 steps .Pass (
369- second_match_name ,
369+ second_match_name ,
370370 result = second_choice_state_result
371371 )
372372 )
373373
374374 workflow = Workflow (
375- 'Test_Choice_Workflow' ,
375+ unique_name_from_base ( 'Test_Choice_Workflow' ) ,
376376 definition = definition ,
377377 role = sfn_role_arn
378378 )
@@ -385,10 +385,10 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para
385385 final_state_name = "FinalState"
386386 resource = f"arn:{ get_aws_partition ()} :states:::sagemaker:createTrainingJob.sync"
387387 task_state_result = "Task State Result"
388- asl_state_machine_definition = {
388+ asl_state_machine_definition = {
389389 "StartAt" : task_state_name ,
390- "States" : {
391- task_state_name : {
390+ "States" : {
391+ task_state_name : {
392392 "Resource" : resource ,
393393 "Parameters" : training_job_parameters ,
394394 "Type" : "Task" ,
@@ -410,9 +410,9 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para
410410 ),
411411 steps .Pass (final_state_name , result = task_state_result )
412412 ])
413-
413+
414414 workflow = Workflow (
415- 'Test_Task_Workflow' ,
415+ unique_name_from_base ( 'Test_Task_Workflow' ) ,
416416 definition = definition ,
417417 role = sfn_role_arn
418418 )
@@ -465,13 +465,13 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
465465 )
466466 task .add_catch (
467467 steps .Catch (
468- error_equals = [all_fail_error ],
468+ error_equals = [all_fail_error ],
469469 next_step = steps .Pass (all_error_state_name , result = catch_state_result )
470470 )
471471 )
472472
473473 workflow = Workflow (
474- 'Test_Catch_Workflow' ,
474+ unique_name_from_base ( 'Test_Catch_Workflow' ) ,
475475 definition = task ,
476476 role = sfn_role_arn
477477 )
@@ -518,15 +518,15 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
518518
519519 task .add_retry (
520520 steps .Retry (
521- error_equals = [all_fail_error ],
522- interval_seconds = interval_seconds ,
523- max_attempts = max_attempts ,
521+ error_equals = [all_fail_error ],
522+ interval_seconds = interval_seconds ,
523+ max_attempts = max_attempts ,
524524 backoff_rate = backoff_rate
525525 )
526526 )
527527
528528 workflow = Workflow (
529- 'Test_Retry_Workflow' ,
529+ unique_name_from_base ( 'Test_Retry_Workflow' ) ,
530530 definition = task ,
531531 role = sfn_role_arn
532532 )
0 commit comments