
Commit 0d59934

committed
Add TF quickstart
1 parent d887135 commit 0d59934

File tree

1 file changed

+274
-0
lines changed

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using TensorFlow in SageMaker - Quickstart\n",
"\n",
"Starting with TensorFlow framework version 1.11, you can use the SageMaker TensorFlow Container to train any TensorFlow script.\n",
"\n",
"For this example, you use [Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow), but you can apply the same technique to other scripts or repositories, such as the [TensorFlow Model Zoo](https://github.com/tensorflow/models) and the [TensorFlow benchmark scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get the data\n",
"For training data, use plain text versions of Sherlock Holmes stories."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"data_dir = os.path.join(os.getcwd(), 'sherlock')\n",
"\n",
"os.makedirs(data_dir, exist_ok=True)\n",
"\n",
"!wget https://sherlock-holm.es/stories/plain-text/cnus.txt --force-directories --output-document=sherlock/input.txt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing the training script\n",
"\n",
"Let's start by cloning the repository that contains the example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/sherjilozair/char-rnn-tensorflow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This repository includes a [README.md](https://github.com/sherjilozair/char-rnn-tensorflow/blob/master/README.md#basic-usage) with an overview of the project, requirements, and basic usage:\n",
"\n",
"> #### **Basic Usage**\n",
"> _To train with default parameters on the tinyshakespeare corpus, run **python train.py**. \n",
"To access all the parameters use **python train.py --help.**_\n",
"\n",
"[train.py](https://github.com/sherjilozair/char-rnn-tensorflow/blob/master/train.py#L11) uses the Python [argparse](https://docs.python.org/3/library/argparse.html) library and defines the following arguments:\n",
"\n",
"```python\n",
"parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
"# Data and model checkpoints directories\n",
"parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare', help='data directory containing input.txt with training examples')\n",
"parser.add_argument('--save_dir', type=str, default='save', help='directory to store checkpointed models')\n",
"...\n",
"args = parser.parse_args()\n",
"\n",
"```\n",
"When SageMaker training finishes, it deletes all data generated inside the container, with the exception of the directory _/opt/ml/model_. To ensure that model data is not lost, SageMaker invokes training scripts with an additional argument, **--model_dir**, which the training script needs to handle. We need to replace the argument **--save_dir** with the required argument **--model_dir**: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this command replaces save_dir with model_dir in the training script\n",
"!sed -i 's/save_dir/model_dir/g' char-rnn-tensorflow/train.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training script can now be executed in the container as follows:\n",
"\n",
"> ```bash\n",
"python train.py --num_epochs 1 --data_dir /opt/ml/input/data/training --model_dir /opt/ml/model\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test locally using the SageMaker Python SDK TensorFlow Estimator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use the SageMaker Python SDK [TensorFlow](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/README.rst#training-with-tensorflow) estimator to easily train locally and in SageMaker. To train locally, set the instance type to [local](https://github.com/aws/sagemaker-python-sdk#local-mode) as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import sagemaker\n",
"from sagemaker.tensorflow import TensorFlow\n",
"\n",
"# sets the script arguments --num_epochs and --data_dir\n",
"hyperparameters = {'num_epochs': 1, \n",
" 'data_dir': '/opt/ml/input/data/training'}\n",
"\n",
"estimator = TensorFlow(entry_point='train.py',\n",
" source_dir='char-rnn-tensorflow',\n",
" train_instance_type='local', # Run in local mode\n",
" train_instance_count=1,\n",
" hyperparameters=hyperparameters,\n",
" role=sagemaker.get_execution_role(), # Passes the AWS role used in this notebook to the container\n",
" framework_version='1.11.0', # Uses TensorFlow 1.11\n",
" py_version='py3',\n",
" script_mode=True)\n",
"\n",
"estimator.fit({'training': f'file://{data_dir}'}) # Starts training; creates a data channel named training with the contents of\n",
"# data_dir in the folder /opt/ml/input/data/training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How Script Mode executes the script in the container\n",
"\n",
"The cell above downloads the SageMaker TensorFlow container (Python 3, CPU version) locally and simulates a SageMaker training job. \n",
"When training starts, the SageMaker TensorFlow container executes **train.py**, passing **hyperparameters** and **model_dir** as script arguments. The example above is executed as follows:\n",
"```bash\n",
"python -m train --num_epochs 1 --data_dir /opt/ml/input/data/training --model_dir /opt/ml/model\n",
"```\n",
"\n",
"Let's explain the values of **--data_dir** and **--model_dir** in more detail:\n",
"\n",
"- **/opt/ml/input/data/training** is the directory inside the container where the training data is downloaded. The data is downloaded to this folder because **training** is the channel name defined in ```estimator.fit({'training': inputs})```. See [training data](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-trainingdata) for more information. \n",
"\n",
"- **/opt/ml/model**: use this directory to save models, checkpoints, or any other data. Any data saved in this folder is uploaded to the S3 bucket defined for training. See [model data](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-envvariables) for more information.\n",
"\n",
"### Reading additional information from the container\n",
"\n",
"Often, a user script needs additional information from the container that is not available in ```hyperparameters```.\n",
"SageMaker containers write this information as **environment variables** that are available inside the script.\n",
"\n",
"For example, the script above can read information about the **training** channel provided in the training job request by using the environment variable `SM_CHANNEL_TRAINING` as the default value for the `--data_dir` argument:\n",
"\n",
"```python\n",
"if __name__ == '__main__':\n",
" parser = argparse.ArgumentParser()\n",
" # reads the training input channel location from an environment variable\n",
" parser.add_argument('--data_dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])\n",
"```\n",
"\n",
"Script mode displays the list of available environment variables in the training logs. You can find the [entire list here](https://github.com/aws/sagemaker-containers/blob/master/README.md#environment-variables-full-specification)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training in SageMaker"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After you test the training job locally, upload the dataset to an S3 bucket so SageMaker can access the data during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"\n",
"inputs = sagemaker.Session().upload_data(path='sherlock', key_prefix='datasets/sherlock')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The variable **inputs** returned above is a string with an S3 location that SageMaker Training has\n",
"permission to read data from. **This setup is for educational purposes;\n",
"larger datasets require more robust solutions:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train in SageMaker:\n",
"- change the estimator argument **train_instance_type** to any SageMaker ML instance type available for training.\n",
"- set the **training** channel to an S3 location.\n",
"\n",
"For example:"
]
},
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = TensorFlow(entry_point='train.py',\n",
" source_dir='char-rnn-tensorflow',\n",
" train_instance_type='ml.c4.xlarge', # Executes training on an ml.c4.xlarge instance\n",
" train_instance_count=1,\n",
" hyperparameters=hyperparameters,\n",
" role=sagemaker.get_execution_role(),\n",
" framework_version='1.11.0',\n",
" py_version='py3',\n",
" script_mode=True)\n",
"\n",
"estimator.fit({'training': inputs})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_tensorflow_p36",
"language": "python",
"name": "conda_tensorflow_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
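The argument-handling pattern the notebook describes (reading the SageMaker channel and model locations from environment variables, with local fallbacks) can be sketched as a standalone script. This is a hypothetical illustration, not part of char-rnn-tensorflow: `parse_args` and the fallback defaults are assumptions, while `SM_CHANNEL_TRAINING` and `SM_MODEL_DIR` are the environment variables script mode documents.

```python
import argparse
import os


def parse_args(argv=None):
    # Hypothetical sketch of a SageMaker script-mode entry point.
    # SageMaker injects SM_CHANNEL_TRAINING and SM_MODEL_DIR as environment
    # variables; falling back to local paths lets the same script also run
    # outside the container.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--data_dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAINING',
                                               'data/tinyshakespeare'))
    parser.add_argument('--model_dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR', 'save'))
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args(['--num_epochs', '2'])
    print(args.num_epochs, args.model_dir)
```

Because the environment lookup happens at parse time, the same script picks up container paths when run by SageMaker and local defaults otherwise.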

0 commit comments

Comments
 (0)
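As a rough sketch of how script mode turns the estimator's `hyperparameters` dict into the command-line flags shown in the notebook: `to_cli_args` is a hypothetical helper for illustration only; the real container's argument serialization is more involved.

```python
def to_cli_args(hyperparameters):
    # Loosely mimics how script mode appends the hyperparameters dict
    # as --name value flags when invoking the training script.
    args = []
    for name, value in sorted(hyperparameters.items()):
        args += ['--' + name, str(value)]
    return args


hyperparameters = {'num_epochs': 1, 'data_dir': '/opt/ml/input/data/training'}
print(' '.join(['python', 'train.py'] + to_cli_args(hyperparameters)))
# python train.py --data_dir /opt/ml/input/data/training --num_epochs 1
```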