@@ -92,6 +92,7 @@ def tensorflow_estimator():
9292
9393 estimator .sagemaker_session = MagicMock ()
9494 estimator .sagemaker_session .boto_region_name = 'us-east-1'
95+ estimator .sagemaker_session ._default_bucket = 'sagemaker'
9596
9697 return estimator
9798
@@ -289,6 +290,38 @@ def test_get_expected_model(pca_estimator):
289290 'End' : True
290291 }
291292
293+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
294+ def test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
295+ training_step = TrainingStep ('Training' ,
296+ estimator = tensorflow_estimator ,
297+ data = {'train' : 's3://sagemaker/train' },
298+ job_name = 'tensorflow-job' ,
299+ mini_batch_size = 1024
300+ )
301+ expected_model = training_step .get_expected_model ()
302+ expected_model .entry_point = 'tf_train.py'
303+ model_step = ModelStep ('Create model' , model = expected_model , model_name = 'tf-model' )
304+ assert model_step .to_dict () == {
305+ 'Type' : 'Task' ,
306+ 'Parameters' : {
307+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
308+ 'ModelName' : 'tf-model' ,
309+ 'PrimaryContainer' : {
310+ 'Environment' : {
311+ 'SAGEMAKER_PROGRAM' : 'tf_train.py' ,
312+ 'SAGEMAKER_SUBMIT_DIRECTORY' : 's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz' ,
313+ 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS' : 'false' ,
314+ 'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20' ,
315+ 'SAGEMAKER_REGION' : 'us-east-1' ,
316+ },
317+ 'Image' : expected_model .image ,
318+ 'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']"
319+ }
320+ },
321+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
322+ 'End' : True
323+ }
324+
292325def test_model_step_creation (pca_model ):
293326 step = ModelStep ('Create model' , model = pca_model , model_name = 'pca-model' )
294327 assert step .to_dict () == {
0 commit comments