"Starting by the TensorFlow's framework version 1.11, you can use the SageMaker TensorFlow Container to train any TensorFlow script. \n",
"\n",
"For this example, you use [Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using Tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow), but you can use the same technique for other scripts or repositories. For example, [TensorFlow Model Zoo](https://github.com/tensorflow/models) and [TensorFlow benchmark scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get the data\n",
"For training data, use plain text versions of Sherlock Holmes stories."
"This repository includes a [README.md](https://github.com/sherjilozair/char-rnn-tensorflow/blob/master/README.md#basic-usage) with an overview of the project, requirements, and basic usage:\n",
"\n",
"> #### **Basic Usage**\n",
"> _To train with default parameters on the tinyshakespeare corpus, run **python train.py**. \n",
62
+
"To access all the parameters use **python train.py --help.**_\n",
"\n",
"[train.py](https://github.com/sherjilozair/char-rnn-tensorflow/blob/master/train.py#L11) uses the Python [argparse](https://docs.python.org/3/library/argparse.html) library and requires the following arguments:\n",
"parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare', help='data directory containing input.txt with training examples')\n",
"parser.add_argument('--save_dir', type=str, default='save', help='directory to store checkpointed models')\n",
"...\n",
"args = parser.parse_args()\n",
"\n",
"```\n",
"When SageMaker training finishes, it deletes all data generated inside the container with exception of the directory _/opt/ml/model_. To ensure that model data is not lost during training, training scripts are invoked in SageMaker with an additional argument **--model_dir**, that needs to be handle by the training script. We need to replace the argument **--save_dir** with the required argument **--model_dir**: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this command will replace data_dir by model_dir in the training script\n",
"## Test locally using SageMaker Python SDK TensorFlow Estimator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use the SageMaker Python SDK [TensorFlow](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/README.rst#training-with-tensorflow) estimator to easily train locally and in SageMaker. To train locally, you set the instance type to [local](https://github.com/aws/sagemaker-python-sdk#local-mode) as follow:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import sagemaker\n",
"from sagemaker.tensorflow import TensorFlow\n",
"\n",
"# sets the script arguments --num_epochs and --data_dir\n",
"estimator.fit({'training': f'file://{data_dir}'}) # Starts training and creates a data channel named training with the contents\n",
140
+
"# data_dir in the folder /opt/ml/input/data/training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How Script Mode executes the script in the container\n",
"\n",
"The above cell downloads SageMaker TensorFlow container with TensorFlow Python 3, CPU version, locally and simulates a SageMaker training job. \n",
"When training starts, the SageMaker TensorFlow executes **train.py**, passing **hyperparameters** and **model_dir** as script arguments. The example above is executed as follows:\n",
"Let's explain the values of **--data_dir** and **--model_dir** with more details:\n",
"\n",
"- **/opt/ml/input/data/training** is the directory inside the container where the training data is downloaded. The data is downloaded to this folder because **training** is the channel name defined in ```estimator.fit({'training': inputs})```. See [training data](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-running-container-trainingdata) for more information. \n",
"\n",
"- **/opt/ml/model** use this directory to save models, checkpoints, or any other data. Any data saved in this folder is saved in the S3 bucket defined for training. See [model data](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html#your-algorithms-training-algo-envvariables) for more information.\n",
"\n",
"### Reading additional information from the container\n",
"\n",
"Often, a user script needs additional information from the container that is not available in ```hyperparameters```.\n",
"SageMaker containers write this information as **environment variables** that are available inside the script.\n",
"\n",
"For example, the example above can read information about the **training** channel provided in the training job request by adding the environment variable `SM_CHANNEL_TRAINING` as the default value for the `--data_dir` argument:\n",
"\n",
"```python\n",
"if __name__ == '__main__':\n",
" parser = argparse.ArgumentParser()\n",
" # reads input channels training and testing from the environment variables\n",
"Script mode displays the list of available environment variables in the training logs. You can find the [entire list here](https://github.com/aws/sagemaker-containers/blob/master/README.md#environment-variables-full-specification)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training in SageMaker"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After you test the training job locally, upload the dataset to an S3 bucket so SageMaker can access the data during training:"