Skip to content

Commit 35941a3

Browse files
apacker authored and laurenyu committed
Fix local dataset generation in TFRecord example notebook (aws#682)
This commit also makes the dataset generation and model creation cells re-runnable, and adds a wait to the batch transform step.
1 parent b573fe5 commit 35941a3

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

advanced_functionality/working_with_tfrecords/working-with-tfrecords.ipynb

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,12 +90,13 @@
9090
" return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))\n",
9191
"\n",
9292
"# create the tfrecord dataset dir\n",
93-
"os.mkdir(tfrecord_root)\n",
93+
"if not os.path.isdir(tfrecord_root):\n",
94+
" os.mkdir(tfrecord_root)\n",
9495
"\n",
9596
"for input_file, output_file in [(test_csv_file,test_tfrecord_file), (train_csv_file,train_tfrecord_file)]:\n",
9697
" # create the output file\n",
9798
" open(tfrecord_root + output_file, 'a').close()\n",
98-
" with tf.python_io.TFRecordWriter(output_file) as writer:\n",
99+
" with tf.python_io.TFRecordWriter(tfrecord_root + output_file) as writer:\n",
99100
" with open(csv_root + input_file,'r') as f:\n",
100101
" f.readline() # skip first line\n",
101102
" for line in f:\n",
@@ -105,7 +106,7 @@
105106
" 'petal_length': _floatlist_feature(line.split(',')[2]),\n",
106107
" 'petal_width': _floatlist_feature(line.split(',')[3]),\n",
107108
" }\n",
108-
" if file == train_csv_file:\n",
109+
" if f == train_csv_file:\n",
109110
" feature['label'] = _int64list_feature(int(line.split(',')[4].rstrip()))\n",
110111
" example = tf.train.Example(\n",
111112
" features=tf.train.Features(\n",
@@ -266,10 +267,11 @@
266267
"outputs": [],
267268
"source": [
268269
"from sagemaker.tensorflow.serving import Model\n",
270+
"from sagemaker.utils import name_from_base\n",
269271
"\n",
270272
"client = boto3.client('sagemaker')\n",
271273
"\n",
272-
"model_name = 'tfrecord-to-tfserving'\n",
274+
"model_name = name_from_base('tfrecord-to-tfserving')\n",
273275
"\n",
274276
"transform_container = {\n",
275277
" \"Image\": transformer_repository_uri\n",
@@ -324,7 +326,8 @@
324326
" sagemaker_session=sess,\n",
325327
")\n",
326328
"transformer.transform(data = input_data_path,\n",
327-
" split_type = 'TFRecord')"
329+
" split_type = 'TFRecord')\n",
330+
"transformer.wait()"
328331
]
329332
},
330333
{

0 commit comments

Comments
 (0)