@@ -92,6 +92,7 @@ def tensorflow_estimator():
9292
9393 estimator .sagemaker_session  =  MagicMock ()
9494 estimator .sagemaker_session .boto_region_name  =  'us-east-1' 
95+  estimator .sagemaker_session ._default_bucket  =  'sagemaker' 
9596
9697 return  estimator 
9798
@@ -289,6 +290,38 @@ def test_get_expected_model(pca_estimator):
289290 'End' : True 
290291 }
291292
293+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call ) 
294+ def  test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
295+  training_step  =  TrainingStep ('Training' ,
296+  estimator = tensorflow_estimator ,
297+  data = {'train' : 's3://sagemaker/train' },
298+  job_name = 'tensorflow-job' ,
299+  mini_batch_size = 1024 
300+  )
301+  expected_model  =  training_step .get_expected_model ()
302+  expected_model .entry_point  =  'tf_train.py' 
303+  model_step  =  ModelStep ('Create model' , model = expected_model , model_name = 'tf-model' )
304+  assert  model_step .to_dict () ==  {
305+  'Type' : 'Task' ,
306+  'Parameters' : {
307+  'ExecutionRoleArn' : EXECUTION_ROLE ,
308+  'ModelName' : 'tf-model' ,
309+  'PrimaryContainer' : {
310+  'Environment' : {
311+  'SAGEMAKER_PROGRAM' : 'tf_train.py' ,
312+  'SAGEMAKER_SUBMIT_DIRECTORY' : 's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz' ,
313+  'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS' : 'false' ,
314+  'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20' ,
315+  'SAGEMAKER_REGION' : 'us-east-1' ,
316+  },
317+  'Image' : expected_model .image ,
318+  'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']" 
319+  }
320+  },
321+  'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
322+  'End' : True 
323+  }
324+ 
292325def  test_model_step_creation (pca_model ):
293326 step  =  ModelStep ('Create model' , model = pca_model , model_name = 'pca-model' )
294327 assert  step .to_dict () ==  {
0 commit comments