Skip to content

Commit cc687f5

Browse files
Aloha106j3ffreyjohn
authored andcommitted
Fix scikit bring-your-own local testing with Batch IO joining. (aws#771)
1 parent d4e4364 commit cc687f5

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

advanced_functionality/scikit_bring_your_own/container/decision_trees/predictor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ def transformation():
7272

7373
print('Invoked with {} records'.format(data.shape[0]))
7474

75-
# Drop first column, since sample notebook uses training data to show case predictions
76-
data.drop(data.columns[[0]],axis=1,inplace=True)
77-
7875
# Do the prediction
7976
predictions = ScoringService.predict(data)
8077

advanced_functionality/scikit_bring_your_own/scikit_bring_your_own.ipynb

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@
292292
"\n",
293293
"The scripts are:\n",
294294
"\n",
295-
"* `train_local.sh`: Run this with the name of the image and it will run training on the local tree. You'll want to modify the directory `test_dir/input/data/...` to be set up with the correct channels and data for your algorithm. Also, you'll want to modify the file `input/config/hyperparameters.json` to have the hyperparameter settings that you want to test (as strings).\n",
296-
"* `serve_local.sh`: Run this with the name of the image once you've trained the model and it should serve the model. It will run and wait for requests. Simply use the keyboard interrupt to stop it.\n",
295+
"* `train_local.sh`: Run this with the name of the image and it will run training on the local tree. For example, you can run `$ ./train_local.sh sagemaker-decision-trees`. It will generate a model under the `/test_dir/model` directory. You'll want to modify the directory `test_dir/input/data/...` to be set up with the correct channels and data for your algorithm. Also, you'll want to modify the file `input/config/hyperparameters.json` to have the hyperparameter settings that you want to test (as strings).\n",
296+
"* `serve_local.sh`: Run this with the name of the image once you've trained the model and it should serve the model. For example, you can run `$ ./serve_local.sh sagemaker-decision-trees`. It will run and wait for requests. Simply use the keyboard interrupt to stop it.\n",
297297
"* `predict.sh`: Run this with the name of a payload file and (optionally) the HTTP content type you want. The content type will default to `text/csv`. For example, you can run `$ ./predict.sh payload.csv text/csv`.\n",
298298
"\n",
299299
"The directories as shipped are set up to test the decision trees sample algorithm presented here."
@@ -459,7 +459,26 @@
459459
"outputs": [],
460460
"source": [
461461
"shape=pd.read_csv(\"data/iris.csv\", header=None)\n",
462-
"\n",
462+
"shape.sample(3)"
463+
]
464+
},
465+
{
466+
"cell_type": "code",
467+
"execution_count": null,
468+
"metadata": {},
469+
"outputs": [],
470+
"source": [
471+
"# drop the label column in the training set\n",
472+
"shape.drop(shape.columns[[0]],axis=1,inplace=True)\n",
473+
"shape.sample(3)"
474+
]
475+
},
476+
{
477+
"cell_type": "code",
478+
"execution_count": null,
479+
"metadata": {},
480+
"outputs": [],
481+
"source": [
463482
"import itertools\n",
464483
"\n",
465484
"a = [50*i for i in range(3)]\n",
@@ -533,7 +552,9 @@
533552
"\n",
534553
"transformer = tree.transformer(instance_count=1,\n",
535554
" instance_type='ml.m4.xlarge',\n",
536-
" output_path=output_path)"
555+
" output_path=output_path,\n",
556+
" assemble_with='Line',\n",
557+
" accept='text/csv')"
537558
]
538559
},
539560
{
@@ -544,7 +565,8 @@
544565
"\n",
545566
"* The __data_location__ which is the location of input data\n",
546567
"* The __content_type__ which is the content type set when making HTTP request to container to get prediction\n",
547-
"* The __split_type__ which is the delimiter used for splitting input data "
568+
"* The __split_type__ which is the delimiter used for splitting input data \n",
569+
"* The __input_filter__ which indicates the first column (ID) of the input will be dropped before making HTTP request to container"
548570
]
549571
},
550572
{
@@ -553,7 +575,7 @@
553575
"metadata": {},
554576
"outputs": [],
555577
"source": [
556-
"transformer.transform(data_location, content_type='text/csv', split_type='Line')\n",
578+
"transformer.transform(data_location, content_type='text/csv', split_type='Line', input_filter='$[1:]')\n",
557579
"transformer.wait()"
558580
]
559581
},

0 commit comments

Comments
 (0)