@@ -68,6 +68,41 @@ def pca_estimator():
6868 return  pca 
6969
7070
71+ @pytest .fixture  
72+ def  pca_estimator_with_env ():
73+  s3_output_location  =  's3://sagemaker/models' 
74+ 
75+  pca  =  sagemaker .estimator .Estimator (
76+  PCA_IMAGE ,
77+  role = EXECUTION_ROLE ,
78+  instance_count = 1 ,
79+  instance_type = 'ml.c4.xlarge' ,
80+  output_path = s3_output_location ,
81+  environment = {
82+  'JobName' : "job_name" ,
83+  'ModelName' : "model_name" 
84+  },
85+  subnets = [
86+  'subnet-00000000000000000' ,
87+  'subnet-00000000000000001' 
88+  ]
89+  )
90+ 
91+  pca .set_hyperparameters (
92+  feature_dim = 50000 ,
93+  num_components = 10 ,
94+  subtract_mean = True ,
95+  algorithm_mode = 'randomized' ,
96+  mini_batch_size = 200 
97+  )
98+ 
99+  pca .sagemaker_session  =  MagicMock ()
100+  pca .sagemaker_session .boto_region_name  =  'us-east-1' 
101+  pca .sagemaker_session ._default_bucket  =  'sagemaker' 
102+ 
103+  return  pca 
104+ 
105+ 
71106@pytest .fixture  
72107def  pca_estimator_with_debug_hook ():
73108 s3_output_location  =  's3://sagemaker/models' 
@@ -156,6 +191,31 @@ def pca_model():
156191 )
157192
158193
194+ @pytest .fixture  
195+ def  pca_model_with_env ():
196+  model_data  =  's3://sagemaker/models/pca.tar.gz' 
197+  return  Model (
198+  model_data = model_data ,
199+  image_uri = PCA_IMAGE ,
200+  role = EXECUTION_ROLE ,
201+  name = 'pca-model' ,
202+  env = {
203+  'JobName' : "job_name" ,
204+  'ModelName' : "model_name" 
205+  },
206+  vpc_config = {
207+  "SecurityGroupIds" : ["sg-00000000000000000" ],
208+  "Subnets" : ["subnet-00000000000000000" , "subnet-00000000000000001" ]
209+  },
210+  image_config = {
211+  "RepositoryAccessMode" : "Vpc" ,
212+  "RepositoryAuthConfig" : {
213+  "RepositoryCredentialsProviderArn" : "arn" 
214+  }
215+  }
216+  )
217+ 
218+ 
159219@pytest .fixture  
160220def  pca_transformer (pca_model ):
161221 return  Transformer (
@@ -537,6 +597,63 @@ def test_training_step_creation_with_model(pca_estimator):
537597 }
538598
539599
600+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call ) 
601+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
602+ def  test_training_step_creation_with_model_with_env (pca_estimator_with_env ):
603+  training_step  =  TrainingStep ('Training' , estimator = pca_estimator_with_env , job_name = 'TrainingJob' )
604+  model_step  =  ModelStep ('Training - Save Model' , training_step .get_expected_model (model_name = training_step .output ()['TrainingJobName' ]))
605+  training_step .next (model_step )
606+  assert  training_step .to_dict () ==  {
607+  'Type' : 'Task' ,
608+  'Parameters' : {
609+  'AlgorithmSpecification' : {
610+  'TrainingImage' : PCA_IMAGE ,
611+  'TrainingInputMode' : 'File' 
612+  },
613+  'OutputDataConfig' : {
614+  'S3OutputPath' : 's3://sagemaker/models' 
615+  },
616+  'StoppingCondition' : {
617+  'MaxRuntimeInSeconds' : 86400 
618+  },
619+  'ResourceConfig' : {
620+  'InstanceCount' : 1 ,
621+  'InstanceType' : 'ml.c4.xlarge' ,
622+  'VolumeSizeInGB' : 30 
623+  },
624+  'RoleArn' : EXECUTION_ROLE ,
625+  'HyperParameters' : {
626+  'feature_dim' : '50000' ,
627+  'num_components' : '10' ,
628+  'subtract_mean' : 'True' ,
629+  'algorithm_mode' : 'randomized' ,
630+  'mini_batch_size' : '200' 
631+  },
632+  'TrainingJobName' : 'TrainingJob' 
633+  },
634+  'Resource' : 'arn:aws:states:::sagemaker:createTrainingJob.sync' ,
635+  'Next' : 'Training - Save Model' 
636+  }
637+ 
638+  assert  model_step .to_dict () ==  {
639+  'Type' : 'Task' ,
640+  'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
641+  'Parameters' : {
642+  'ExecutionRoleArn' : EXECUTION_ROLE ,
643+  'ModelName.$' : "$['TrainingJobName']" ,
644+  'PrimaryContainer' : {
645+  'Environment' : {
646+  'JobName' : 'job_name' ,
647+  'ModelName' : 'model_name' 
648+  },
649+  'Image' : PCA_IMAGE ,
650+  'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']" 
651+  }
652+  },
653+  'End' : True 
654+  }
655+ 
656+ 
540657@patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call ) 
541658@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
542659def  test_training_step_creation_with_framework (tensorflow_estimator ):
@@ -806,6 +923,31 @@ def test_get_expected_model(pca_estimator):
806923 }
807924
808925
926+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call ) 
927+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
928+ def  test_get_expected_model_with_env (pca_estimator_with_env ):
929+  training_step  =  TrainingStep ('Training' , estimator = pca_estimator_with_env , job_name = 'TrainingJob' )
930+  expected_model  =  training_step .get_expected_model ()
931+  model_step  =  ModelStep ('Create model' , model = expected_model , model_name = 'pca-model' )
932+  assert  model_step .to_dict () ==  {
933+  'Type' : 'Task' ,
934+  'Parameters' : {
935+  'ExecutionRoleArn' : EXECUTION_ROLE ,
936+  'ModelName' : 'pca-model' ,
937+  'PrimaryContainer' : {
938+  'Environment' : {
939+  'JobName' : 'job_name' ,
940+  'ModelName' : 'model_name' 
941+  },
942+  'Image' : expected_model .image_uri ,
943+  'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']" 
944+  }
945+  },
946+  'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
947+  'End' : True 
948+  }
949+ 
950+ 
809951@patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call ) 
810952@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
811953def  test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
@@ -859,6 +1001,29 @@ def test_model_step_creation(pca_model):
8591001 }
8601002
8611003
1004+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
1005+ def  test_model_step_creation_with_env (pca_model_with_env ):
1006+  step  =  ModelStep ('Create model' , model = pca_model_with_env , model_name = 'pca-model' , tags = DEFAULT_TAGS )
1007+  assert  step .to_dict () ==  {
1008+  'Type' : 'Task' ,
1009+  'Parameters' : {
1010+  'ExecutionRoleArn' : EXECUTION_ROLE ,
1011+  'ModelName' : 'pca-model' ,
1012+  'PrimaryContainer' : {
1013+  'Environment' : {
1014+  'JobName' : 'job_name' ,
1015+  'ModelName' : 'model_name' 
1016+  },
1017+  'Image' : pca_model_with_env .image_uri ,
1018+  'ModelDataUrl' : pca_model_with_env .model_data 
1019+  },
1020+  'Tags' : DEFAULT_TAGS_LIST 
1021+  },
1022+  'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
1023+  'End' : True 
1024+  }
1025+ 
1026+ 
8621027@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' ) 
8631028def  test_endpoint_config_step_creation (pca_model ):
8641029 data_capture_config  =  DataCaptureConfig (
0 commit comments