Skip to content

Commit 53e38d9

Browse files
committed
Adding debug hook and rules configuration with tests
1 parent a305498 commit 53e38d9

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,12 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6666
else:
6767
parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
6868

69+
if estimator.debugger_hook_config != None:
70+
parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
71+
72+
if estimator.rules != None:
73+
parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
74+
6975
if isinstance(job_name, (ExecutionInput, StepInput)):
7076
parameters['TrainingJobName'] = job_name
7177

tests/unit/test_sagemaker_steps.py

Lines changed: 108 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.tensorflow import TensorFlow
2222
from sagemaker.pipeline import PipelineModel
2323
from sagemaker.model_monitor import DataCaptureConfig
24+
from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig
2425

2526
from unittest.mock import MagicMock, patch
2627
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep
@@ -58,6 +59,54 @@ def pca_estimator():
5859

5960
return pca
6061

62+
@pytest.fixture
def pca_estimator_with_debug_hook():
    """Return a PCA ``Estimator`` that carries a debugger hook config and one debugger rule.

    The estimator's ``sagemaker_session`` is replaced with a ``MagicMock`` whose
    ``boto_region_name`` and ``_default_bucket`` are preset.
    """
    # Debugger hook configuration with two named collections.
    collections = [
        CollectionConfig(collection_name)
        for collection_name in ("hyperparameters", "metrics")
    ]
    debug_hook = DebuggerHookConfig(
        s3_output_path='s3://sagemaker/output/debug',
        hook_parameters={"save_interval": "1"},
        collection_configs=collections,
    )

    # One built-in "confusion" rule with explicit rule parameters.
    confusion_parameters = {
        "category_no": "15",
        "min_diag": "0.7",
        "max_off_diag": "0.3",
        "start_step": "17",
        "end_step": "19",
    }
    debug_rules = [
        Rule.sagemaker(rule_configs.confusion(), rule_parameters=confusion_parameters),
    ]

    estimator = sagemaker.estimator.Estimator(
        PCA_IMAGE,
        role=EXECUTION_ROLE,
        train_instance_count=1,
        train_instance_type='ml.c4.xlarge',
        output_path='s3://sagemaker/models',
        debugger_hook_config=debug_hook,
        rules=debug_rules,
    )

    hyperparameters = {
        'feature_dim': 50000,
        'num_components': 10,
        'subtract_mean': True,
        'algorithm_mode': 'randomized',
        'mini_batch_size': 200,
    }
    estimator.set_hyperparameters(**hyperparameters)

    # Swap in a mock session and preset the attributes on it.
    mock_session = MagicMock()
    estimator.sagemaker_session = mock_session
    mock_session.boto_region_name = 'us-east-1'
    mock_session._default_bucket = 'sagemaker'

    return estimator
109+
61110
@pytest.fixture
62111
def pca_model():
63112
model_data = 's3://sagemaker/models/pca.tar.gz'
@@ -148,6 +197,65 @@ def test_training_step_creation(pca_estimator):
148197
'End': True
149198
}
150199

200+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
    """Check that a TrainingStep built from a debug-enabled estimator serializes
    ``DebugHookConfig`` and ``DebugRuleConfigurations`` into its parameters."""
    training_step = TrainingStep(
        'Training',
        estimator=pca_estimator_with_debug_hook,
        job_name='TrainingJob',
    )
    actual_definition = training_step.to_dict()

    # Expected debugger sections, mirroring the fixture's hook config and rule.
    expected_debug_hook = {
        'S3OutputPath': 's3://sagemaker/output/debug',
        'HookParameters': {'save_interval': '1'},
        'CollectionConfigurations': [
            {'CollectionName': 'hyperparameters'},
            {'CollectionName': 'metrics'},
        ],
    }
    expected_debug_rules = [
        {
            'RuleConfigurationName': 'Confusion',
            'RuleEvaluatorImage': '503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest',
            'RuleParameters': {
                'rule_to_invoke': 'Confusion',
                'category_no': '15',
                'min_diag': '0.7',
                'max_off_diag': '0.3',
                'start_step': '17',
                'end_step': '19',
            },
        },
    ]

    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': PCA_IMAGE,
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {
            'S3OutputPath': 's3://sagemaker/models',
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': 86400,
        },
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.c4.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': {
            'feature_dim': '50000',
            'num_components': '10',
            'subtract_mean': 'True',
            'algorithm_mode': 'randomized',
            'mini_batch_size': '200',
        },
        'DebugHookConfig': expected_debug_hook,
        'DebugRuleConfigurations': expected_debug_rules,
        'TrainingJobName': 'TrainingJob',
    }
    expected_definition = {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }

    assert actual_definition == expected_definition
258+
151259
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
152260
def test_training_step_creation_with_model(pca_estimator):
153261
training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')

0 commit comments

Comments
 (0)