Commit fe19e7b
Sagemaker debugger stalled training rule example (aws#1301)

File tree: 2 files changed (+352, -0 lines changed)
Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Detect stalled training and stop the training job using a Debugger rule\n",
"\n",
"In this notebook, we show how to use the StalledTrainingRule, which can take an action such as stopping your training job when it finds that the job has produced no updates for a configured threshold duration.\n",
"\n",
"## How does StalledTrainingRule work?\n",
"\n",
"Amazon SageMaker Debugger automatically captures tensors from training jobs that use the AWS Deep Learning Containers (TensorFlow, PyTorch, MXNet, XGBoost); [refer to the documentation for supported versions](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/sagemaker.md#zero-script-change). StalledTrainingRule watches for the emission of tensors such as the loss, and its execution happens outside of the training container. A training job that is running well and is not stalled is expected to emit loss and metric tensors at frequent intervals; if the rule does not see new tensors emitted from the training job for the threshold period of time, it automatically issues StopTrainingJob.\n",
"\n",
"#### With no changes to your training script\n",
"If you use one of the SageMaker-provided [Deep Learning Containers](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html) ([refer to the documentation for supported framework versions](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/sagemaker.md#zero-script-change)), you don't need to make any changes to your training script to activate this rule: loss tensors are automatically captured and monitored by the rule.\n",
"\n",
"You can also emit tensors periodically by using the [save_scalar API of the hook](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/api.md#common-hook-api).\n",
"\n",
"An example of how to use the save_scalar API is [here](https://github.com/awslabs/sagemaker-debugger/blob/master/examples/tensorflow2/scripts/tf_keras_fit_non_eager.py#L42); a minimal sketch follows below.\n",
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install -q sagemaker"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"import os\n",
"import sagemaker\n",
"from sagemaker.tensorflow import TensorFlow\n",
"print(sagemaker.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.debugger import Rule, DebuggerHookConfig, TensorBoardOutputConfig, CollectionConfig\n",
"import smdebug_rulesconfig as rule_configs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define the entrypoint script\n",
63+
"# Below script has 5 minutes sleep, we will create a stalledTrainingRule with 3 minutes of threshold.\n",
64+
"entrypoint_script='src/simple_stalled_training.py'\n",
65+
"\n",
66+
"# these hyperparameters ensure that vanishing gradient will trigger for our tensorflow mnist script\n",
67+
"hyperparameters = {\n",
68+
" \"num_epochs\": \"10\",\n",
69+
" \"lr\": \"10.00\"\n",
70+
"}"
71+
]
72+
},
73+
{
74+
"cell_type": "markdown",
75+
"metadata": {},
76+
"source": [
77+
"### Create unique training job prefix\n",
78+
"We will create unique training job name prefix. this prefix would be passed to StalledTrainingRule to identify which training job, rule should take action on once the stalled training rule condition is met.\n",
79+
"Note that, this prefix needs to be unique. If rule doesn't find exactly one job with provided prefix, it will fallback to safe mode and not take action of stop training job. Rule will still emit a cloudwatch event if the rule condition is met. To see details about cloud watch event, check [here](https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-debugger/tensorflow_action_on_rule/tf-mnist-stop-training-job.ipynb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"print(int(time.time()))\n",
"# SageMaker appends a timestamp to your training job name and truncates the provided base name to 39 characters,\n",
"# so we keep the prefix below shorter than that. Appending the current time keeps the prefix unique.\n",
"base_job_name_prefix = 'smdebug-stalled-demo-' + str(int(time.time()))\n",
"base_job_name_prefix = base_job_name_prefix[:34]\n",
"print(base_job_name_prefix)"
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": null,
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
"stalled_training_job_rule = Rule.sagemaker(\n",
"    base_config={\n",
"        'DebugRuleConfiguration': {\n",
"            'RuleConfigurationName': 'StalledTrainingRule',\n",
"            'RuleParameters': {'rule_to_invoke': 'StalledTrainingRule'}\n",
"        }\n",
"    },\n",
"    rule_parameters={\n",
"        # seconds without new tensors after which the rule fires\n",
"        'threshold': '120',\n",
"        # unique prefix identifying the training job the rule may stop\n",
"        'training_job_name_prefix': base_job_name_prefix,\n",
"        # automatically issue StopTrainingJob when the rule fires\n",
"        'stop_training_on_fire': 'True'\n",
"    },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = TensorFlow(\n",
"    role=sagemaker.get_execution_role(),\n",
"    base_job_name=base_job_name_prefix,\n",
"    train_instance_count=1,\n",
"    train_instance_type='ml.m5.4xlarge',\n",
"    entry_point=entrypoint_script,\n",
"    # source_dir='src',\n",
"    framework_version='1.15.0',\n",
"    py_version='py3',\n",
"    train_max_run=3600,\n",
"    script_mode=True,\n",
"    # New parameter: attach the stalled-training rule to the training job\n",
"    rules=[stalled_training_job_rule]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# After calling fit, SageMaker spins up one training job and one rule evaluation job for you.\n",
"# The rule evaluation status(es) are visible in the training logs at regular intervals.\n",
"# wait=True streams the logs in the notebook; pass wait=False to make this a fire-and-forget call.\n",
"\n",
"estimator.fit(wait=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Monitoring\n",
159+
"\n",
160+
"SageMaker kicked off rule evaluation job `StalledTrainingRule` as specified in the estimator. \n",
161+
"Given that we've stalled our training script for 10 minutes such that `StalledTrainingRule` is bound to fire and take action StopTrainingJob, we should expect to see the `TrainingJobStatus` as\n",
162+
"`Stopped` once the `RuleEvaluationStatus` for `StalledTrainingRule` changes to `IssuesFound`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# rule_job_summary() gives a summary of the rule evaluations. You might have to run it\n",
"# a few times before all values are populated and start changing.\n",
"estimator.latest_training_job.rule_job_summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This utility builds the CloudWatch Logs URL for each rule evaluation job so you can monitor it\n",
"def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):\n",
"    \"\"\"Helper function to get the rule job name\"\"\"\n",
"    return \"{}-{}-{}\".format(\n",
"        training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]\n",
"    )\n",
"\n",
"def _get_cw_url_for_rule_job(rule_job_name, region):\n",
"    return \"https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix\".format(region, region, rule_job_name)\n",
"\n",
"\n",
"def get_rule_jobs_cw_urls(estimator):\n",
"    region = boto3.Session().region_name\n",
"    training_job = estimator.latest_training_job\n",
"    training_job_name = training_job.describe()[\"TrainingJobName\"]\n",
"    rule_eval_statuses = training_job.describe()[\"DebugRuleEvaluationStatuses\"]\n",
"\n",
"    result = {}\n",
"    for status in rule_eval_statuses:\n",
"        if status.get(\"RuleEvaluationJobArn\", None) is not None:\n",
"            rule_job_name = _get_rule_job_name(training_job_name, status[\"RuleConfigurationName\"], status[\"RuleEvaluationJobArn\"])\n",
"            result[status[\"RuleConfigurationName\"]] = _get_cw_url_for_rule_job(rule_job_name, region)\n",
"    return result\n",
"\n",
"get_rule_jobs_cw_urls(estimator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After re-running the last two cells until `StalledTrainingRule` reports `IssuesFound`, we describe the `TrainingJobStatus` of our training job."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator.latest_training_job.describe()[\"TrainingJobStatus\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Result\n",
"\n",
"This notebook showed a very simple setup for using the StalledTrainingRule to detect a stalled training job and stop it automatically once the rule evaluation status changes. Learn more about Amazon SageMaker Debugger in the [GitHub documentation](https://github.com/awslabs/sagemaker-debugger)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
"""
This script is a simple training script which uses TensorFlow's MonitoredSession interface.
It is designed to be used with SageMaker Debugger in an official SageMaker framework container (i.e. an AWS Deep Learning Container).
Here we create the hook object, which loads its configuration from the JSON file that SageMaker
puts in the training container based on the configuration provided through the SageMaker Python SDK when creating a job.
We use this hook object to add our custom loss to the losses collection and to set the mode.
For more information, please refer to https://github.com/awslabs/sagemaker-debugger/blob/master/docs/
"""

# Standard Library
import argparse
import random

# Third Party
import numpy as np
import tensorflow.compat.v1 as tf

# First Party
import smdebug.tensorflow as smd


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, help="S3 path for the model")
parser.add_argument("--lr", type=float, help="Learning Rate", default=0.001)
parser.add_argument("--steps", type=int, help="Number of steps to run", default=100)
parser.add_argument("--scale", type=float, help="Scaling factor for inputs", default=1.0)
# use str2bool so that "--random_seed False" parses as False rather than as a truthy string
parser.add_argument("--random_seed", type=str2bool, default=False)
args = parser.parse_args()

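# When this script is launched through the SageMaker Python SDK in script mode, any
# hyperparameters supplied to the estimator arrive as command-line arguments,
# e.g. (values illustrative): python simple_stalled_training.py --lr 0.001 --steps 100
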
# These random seeds are only intended for test purposes.
# For now, the values 2, 2, 12 avoid assert failures when running tests;
# if you change them, note that certain steps' tensor values may vary.
if args.random_seed:
    tf.set_random_seed(2)
    np.random.seed(2)
    random.seed(12)

# Network definition
# Note the use of name scopes
with tf.name_scope("foobar"):
    x = tf.placeholder(shape=(None, 2), dtype=tf.float32)
    w = tf.Variable(initial_value=[[10.0], [10.0]], name="weight1")
with tf.name_scope("foobaz"):
    w0 = [[1], [1.0]]
    y = tf.matmul(x, w0)
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2, name="loss")

hook = smd.SessionHook.create_from_json_file()
hook.add_to_collection("losses", loss)

global_step = tf.Variable(17, name="global_step", trainable=False)
increment_global_step_op = tf.assign(global_step, global_step + 1)

optimizer = tf.train.AdamOptimizer(args.lr)

# There is no need to wrap the optimizer in a zero-script-change environment
# (i.e. SageMaker / AWS Deep Learning Containers),
# as the framework does that automatically when the hook exists.
optimizer = hook.wrap_optimizer(optimizer)

# Use this wrapped optimizer to minimize the loss
optimizer_op = optimizer.minimize(loss, global_step=increment_global_step_op)

# There is no need to pass the hook to the session in a zero-script-change environment
# (i.e. SageMaker / AWS Deep Learning Containers),
# as the framework does that automatically when the hook exists.
sess = tf.train.MonitoredSession()

# use this session for running the tensorflow model
hook.set_mode(smd.modes.TRAIN)
for i in range(args.steps):
    x_ = np.random.random((10, 2)) * args.scale
    _loss, opt, gstep = sess.run([loss, optimizer_op, increment_global_step_op], {x: x_})
    print(f"Step={i}, Loss={_loss}")

# set the mode for monitored session based runs
# so smdebug can separate out steps by mode
hook.set_mode(smd.modes.EVAL)
for i in range(args.steps):
    x_ = np.random.random((10, 2)) * args.scale
    sess.run([loss, increment_global_step_op], {x: x_})
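
# Deliberately stall the job: emit nothing for 10 minutes so that the StalledTrainingRule
# configured in the accompanying notebook (120-second threshold, stop_training_on_fire=True)
# fires and issues StopTrainingJob.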
import time
print("Sleeping for 10 minutes")
time.sleep(10 * 60)
print("Waking up and exiting")
