
Commit b5735b2

large refactor for latest autogluon version (aws#1289)
1 parent c70e979 commit b5735b2

3 files changed: +172 -246 lines changed


advanced_functionality/autogluon-tabular/AutoGluon_Tabular_SageMaker.ipynb

Lines changed: 62 additions & 57 deletions
@@ -49,6 +49,7 @@
 "import sagemaker\n",
 "from time import sleep\n",
 "from collections import Counter\n",
+"import numpy as np\n",
 "import pandas as pd\n",
 "from sagemaker import get_execution_role, local, Model, utils, fw_utils, s3\n",
 "from sagemaker.estimator import Estimator\n",
@@ -106,12 +107,10 @@
 "source": [
 "if not os.path.exists('package'):\n",
 " !pip install PrettyTable -t package\n",
-" !pip install bokeh -t package\n",
-" !pip install --pre autogluon -t package\n",
-" !pip install numpy==1.16.1 -t package \n",
 " !pip install --upgrade boto3 -t package\n",
 " !pip install bokeh -t package\n",
-" !pip install --upgrade matplotlib -t package"
+" !pip install --upgrade matplotlib -t package\n",
+" !pip install autogluon -t package"
 ]
 },
 {
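The `pip install -t package` calls above stage dependencies in a local `package/` directory that travels with the training and inference code; inside the container, `container-training/inference.py` makes that directory importable before loading the bundled libraries. A minimal sketch of that mechanism (the `/opt/ml/code/package` path is copied from `inference.py`; the imports are illustrative):

```python
import os
import sys

# Make the bundled dependency directory importable inside the container.
# This mirrors the sys.path.append() call in container-training/inference.py.
sys.path.append(os.path.join(os.path.dirname(__file__), '/opt/ml/code/package'))

# Only import the bundled libraries after the path has been extended.
import pandas as pd                                # noqa: E402
from autogluon import TabularPrediction as task    # noqa: E402
```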
@@ -255,80 +254,83 @@
 "Collapsed": "false"
 },
 "source": [
-"## Train"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"Collapsed": "false"
-},
-"source": [
-"The minimum requirement for hyperparameters is a target label."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"Collapsed": "false"
-},
-"outputs": [],
-"source": [
-"hyperparameters = {'label': 'y'}"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"Collapsed": "false"
-},
-"source": [
-"##### (Optional) hyperparameters can be passed to the `autogluon.task.TabularPrediction.fit` function. \n",
+"## Hyperparameter Selection\n",
 "\n",
-"Below shows AutoGluon hyperparameters from the example [Predicting Columns in a Table - In Depth](https://autogluon.mxnet.io/tutorials/tabular_prediction/tabular-indepth.html#model-ensembling-with-stacking-bagging). Please see [fit parameters](https://autogluon.mxnet.io/api/autogluon.task.html?highlight=eval_metric#autogluon.task.TabularPrediction.fit) for further information.\n",
+"The minimum required settings for training is just a target label, `fit_args['label']`.\n",
 "\n",
+"Additional optional hyperparameters can be passed to the `autogluon.task.TabularPrediction.fit` function via `fit_args`.\n",
 "\n",
-"Here's a more in depth example from the above tutorial that shows how to provide hyperparameter ranges and additional settings:\n",
+"Below shows a more in depth example of AutoGluon-Tabular hyperparameters from the example [Predicting Columns in a Table - In Depth](https://autogluon.mxnet.io/tutorials/tabular_prediction/tabular-indepth.html#model-ensembling-with-stacking-bagging). Please see [fit parameters](https://autogluon.mxnet.io/api/autogluon.task.html?highlight=eval_metric#autogluon.task.TabularPrediction.fit) for further information. Note that in order for hyperparameter ranges to work in SageMaker, values passed to the `fit_args['hyperparameters']` must be represented as strings.\n",
 "\n",
 "```python\n",
 "nn_options = {\n",
-" 'num_epochs': '10',\n",
+" 'num_epochs': \"10\",\n",
 " 'learning_rate': \"ag.space.Real(1e-4, 1e-2, default=5e-4, log=True)\",\n",
 " 'activation': \"ag.space.Categorical('relu', 'softrelu', 'tanh')\",\n",
 " 'layers': \"ag.space.Categorical([100],[1000],[200,100],[300,200,100])\",\n",
 " 'dropout_prob': \"ag.space.Real(0.0, 0.5, default=0.1)\"\n",
 "}\n",
 "\n",
 "gbm_options = {\n",
-" 'num_boost_round': '100',\n",
+" 'num_boost_round': \"100\",\n",
 " 'num_leaves': \"ag.space.Int(lower=26, upper=66, default=36)\"\n",
 "}\n",
 "\n",
 "model_hps = {'NN': nn_options, 'GBM': gbm_options} \n",
 "\n",
+"fit_args = {\n",
+" 'label': 'y',\n",
+" 'presets': ['best_quality', 'optimize_for_deployment'],\n",
+" 'time_limits': 60*10,\n",
+" 'hyperparameters': model_hps,\n",
+" 'hyperparameter_tune': True,\n",
+" 'search_strategy': 'skopt'\n",
+"}\n",
+"\n",
 "hyperparameters = {\n",
-" 'label': 'y',\n",
-" 'time_limits': 2*60,\n",
-" 'hyperparameters': model_hps,\n",
-" 'auto_stack': False, \n",
-" 'hyperparameter_tune': True,\n",
-" 'search_strategy': 'skopt'\n",
+" 'fit_args': fit_args,\n",
+" 'feature_importance': True\n",
 "}\n",
 "```\n",
-"**Note:** Your hyperparameter choices may affect the size of the model package, which could result in additional time taken to upload your model and complete training.\n",
+"**Note:** Your hyperparameter choices may affect the size of the model package, which could result in additional time taken to upload your model and complete training. Including `'optimize_for_deployment'` in the list of `fit_args['presets']` is recommended to greatly reduce upload times.\n",
 "\n",
 "<br>"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"Collapsed": "false"
+},
+"outputs": [],
+"source": [
+"# Define required label and optional additional parameters\n",
+"fit_args = {\n",
+" 'label': 'y',\n",
+" # Adding 'best_quality' to presets list will result in better performance (but longer runtime)\n",
+" 'presets': ['optimize_for_deployment'],\n",
+"}\n",
+"\n",
+"# Pass fit_args to SageMaker estimator hyperparameters\n",
+"hyperparameters = {\n",
+" 'fit_args': fit_args,\n",
+" 'feature_importance': True\n",
+"}"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
 "Collapsed": "false"
 },
 "source": [
-"For local training set `train_instance_type` to `local` . \n",
-"For non-local training the recommended instance type is `ml.m5.2xlarge` ."
+"## Train\n",
+"\n",
+"For local training set `train_instance_type` to `local` . \n",
+"For non-local training the recommended instance type is `ml.m5.2xlarge`. \n",
+"\n",
+"**Note:** Depending on how many underlying models are trained, `train_volume_size` may need to be increased so that they all fit on disk."
 ]
 },
 {
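Since SageMaker delivers hyperparameters to the training container as strings, the nested `fit_args` dict defined in the cell above has to be re-parsed inside the container before it can be handed to `task.fit`. A rough sketch of that round trip (the decoding shown here is illustrative; the repo's training script may handle it differently):

```python
import ast

# What the notebook cell above defines:
fit_args = {'label': 'y', 'presets': ['optimize_for_deployment']}
hyperparameters = {'fit_args': fit_args, 'feature_importance': True}

# SageMaker hands hyperparameters to the container as strings
# (in /opt/ml/input/config/hyperparameters.json), roughly:
received = {k: str(v) for k, v in hyperparameters.items()}

# ...so a training script has to turn 'fit_args' back into a dict, e.g.:
decoded_fit_args = ast.literal_eval(received['fit_args'])
assert decoded_fit_args == fit_args
```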
@@ -350,9 +352,13 @@
 " role=role,\n",
 " train_instance_count=1,\n",
 " train_instance_type=instance_type,\n",
-" hyperparameters=hyperparameters)\n",
+" hyperparameters=hyperparameters,\n",
+" train_volume_size=100)\n",
 "\n",
-"estimator.fit(train_s3_path)"
+"# Set inputs. Test data is optional, but requires a label column.\n",
+"inputs = {'training': train_s3_path, 'testing': test_s3_path}\n",
+"\n",
+"estimator.fit(inputs)"
 ]
 },
 {
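For context on the new `inputs` dict: each key of the dict passed to `estimator.fit()` becomes a named input channel that SageMaker mounts under `/opt/ml/input/data/<channel>` inside the training container. A small illustrative check (only meaningful when run inside the container; the file listing is hypothetical):

```python
import os

# The 'training' and 'testing' channel names come from the cell above.
for channel in ('training', 'testing'):
    channel_dir = os.path.join('/opt/ml/input/data', channel)
    if os.path.isdir(channel_dir):  # exists only inside the training container
        print(channel, os.listdir(channel_dir))
```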
@@ -516,10 +522,10 @@
 },
 "outputs": [],
 "source": [
-"results = predictor.predict(X_test.to_csv())\n",
+"results = predictor.predict(X_test.to_csv()).splitlines()\n",
 "\n",
 "# Check output\n",
-"print(Counter(results.splitlines()))"
+"print(Counter(results))"
 ]
 },
 {
@@ -540,10 +546,10 @@
 },
 "outputs": [],
 "source": [
-"results = predictor.predict(test.to_csv())\n",
+"results = predictor.predict(test.to_csv()).splitlines()\n",
 "\n",
 "# Check output\n",
-"sleep(0.1); print(Counter(results.splitlines()))"
+"print(Counter(results))"
 ]
 },
 {
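The `.splitlines()` added in these cells matches what the updated `transform_fn` in `inference.py` returns: a single CSV string with one prediction per line. A toy round trip with stand-in predictions (not the notebook's variables):

```python
from collections import Counter
from io import StringIO
import pandas as pd

predictions = ['no', 'yes', 'no']        # stand-in for net.predict(ds) on the server side
output = StringIO()
pd.DataFrame(predictions).to_csv(output, header=False, index=False)
response_body = output.getvalue()        # "no\nyes\nno\n" -- what the endpoint sends back

results = response_body.splitlines()     # client side: ['no', 'yes', 'no']
print(Counter(results))                  # Counter({'no': 2, 'yes': 1})
```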
@@ -552,7 +558,7 @@
 "Collapsed": "false"
 },
 "source": [
-"##### Check that performance metrics match evaluation printed to endpoint logs as expected"
+"##### Check that classification performance metrics match evaluation printed to endpoint logs as expected"
 ]
 },
 {
@@ -563,8 +569,7 @@
 },
 "outputs": [],
 "source": [
-"import numpy as np\n",
-"y_results = np.array(results.splitlines())\n",
+"y_results = np.array(results)\n",
 "\n",
 "print(\"accuracy: {}\".format(accuracy_score(y_true=y_test, y_pred=y_results)))\n",
 "print(classification_report(y_true=y_test, y_pred=y_results, digits=6))"
@@ -593,7 +598,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "conda_mxnet_p36",
+"display_name": "Environment (conda_mxnet_p36)",
 "language": "python",
 "name": "conda_mxnet_p36"
 },

advanced_functionality/autogluon-tabular/container-training/inference.py

Lines changed: 46 additions & 21 deletions
@@ -3,31 +3,37 @@
 import argparse
 import logging
 import warnings
-import os
+import time
 import json
 import subprocess
 
-warnings.filterwarnings("ignore",category=FutureWarning)
+warnings.filterwarnings('ignore', category=FutureWarning)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '/opt/ml/code/package'))
 
+import numpy as np
 import pandas as pd
 import pickle
 from io import StringIO
 from timeit import default_timer as timer
+from itertools import islice
 from collections import Counter
 
 with warnings.catch_warnings():
-    warnings.filterwarnings("ignore",category=DeprecationWarning)
+    warnings.filterwarnings('ignore', category=DeprecationWarning)
     from prettytable import PrettyTable
     from autogluon import TabularPrediction as task
 
-def make_str_table(df):
+def make_str_table(df):
     table = PrettyTable(['index']+list(df.columns))
    for row in df.itertuples():
        table.add_row(row)
    return str(table)
 
+def take(n, iterable):
+    "Return first n items of the iterable as a list"
+    return list(islice(iterable, n))
+
 # ------------------------------------------------------------ #
 # Hosting methods                                               #
 # ------------------------------------------------------------ #
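The new `take()` helper just materializes the first `n` items of any iterable; `transform_fn` below uses it to cap how many prediction counts get logged when the output has many distinct values (for example, regression). A quick illustration with a toy `Counter`:

```python
from collections import Counter
from itertools import islice

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

pred_counts = Counter({'no': 900, 'yes': 100})
print(dict(take(1, pred_counts.items())))   # {'no': 900}
```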
@@ -39,7 +45,7 @@ def model_fn(model_dir):
     :return: a model (in this case a Gluon network)
     """
     print(f'Loading model from {model_dir} with contents {os.listdir(model_dir)}')
-    net = task.load(model_dir, verbosity=True)
+    net = task.load(model_dir, verbosity=True)
     return net
 
 
@@ -53,37 +59,56 @@ def transform_fn(net, data, input_content_type, output_content_type):
     :return: response payload and content type.
     """
     start = timer()
-
+
     # text/csv
     if input_content_type == 'text/csv':
-
+
         # Load dataset
         df = pd.read_csv(StringIO(data))
         ds = task.Dataset(df=df)
-
-        # Predict
-        predictions = net.predict(ds)
-        print(f'Prediction counts: {Counter(predictions.tolist())}')
 
+        try:
+            predictions = net.predict(ds)
+        except:
+            try:
+                predictions = net.predict(ds.fillna(0.0))
+                warnings.warn('Filled NaN\'s with 0.0 in order to predict.')
+            except Exception as e:
+                response_body = e
+                return response_body, output_content_type
+
+        # Print prediction counts, limit in case of regression problem
+        pred_counts = Counter(predictions.tolist())
+        n_display_items = 30
+        if len(pred_counts) > n_display_items:
+            print(f'Top {n_display_items} prediction counts: '
+                  f'{dict(take(n_display_items, pred_counts.items()))}')
+        else:
+            print(f'Prediction counts: {pred_counts}')
+
         # Form response
         output = StringIO()
         pd.DataFrame(predictions).to_csv(output, header=False, index=False)
-        response_body = output.getvalue()
-
+        response_body = output.getvalue()
+
         # If target column passed, evaluate predictions performance
         target = net.label_column
         if target in ds:
             print(f'Label column ({target}) found in input data. '
-                  'Therefore, evaluating prediction performance...')
-
-            performance = net.evaluate_predictions(y_true=ds[target], y_pred=predictions,
-                                                   auxiliary_metrics=True)
-            print(json.dumps(performance, indent=4))
-
-    else:
+                  'Therefore, evaluating prediction performance...')
+            try:
+                performance = net.evaluate_predictions(y_true=ds[target],
+                                                       y_pred=predictions,
+                                                       auxiliary_metrics=True)
+                print(json.dumps(performance, indent=4))
+                time.sleep(0.1)
+            except Exception as e:
+                # Print exceptions on evaluate, continue to return predictions
+                print(f'Exception: {e}')
+    else:
         raise NotImplementedError("content_type must be 'text/csv'")
 
     elapsed_time = round(timer()-start,3)
     print(f'Elapsed time: {round(timer()-start,3)} seconds')
 
-    return response_body, output_content_type
+    return response_body, output_content_type
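For anyone modifying this handler, a hedged local smoke test that exercises `model_fn` and `transform_fn` without deploying an endpoint. The module name, model path, and columns are assumptions for illustration; it only works where the trained AutoGluon artifacts and bundled dependencies are available (e.g., inside the container):

```python
import pandas as pd
import inference  # assumed module name for container-training/inference.py

# Hypothetical model directory and toy frame; including the label column 'y'
# triggers the evaluation branch inside transform_fn.
net = inference.model_fn('/opt/ml/model')
payload = pd.DataFrame({'age': [39, 50], 'y': ['no', 'yes']}).to_csv(index=False)

body, content_type = inference.transform_fn(net, payload, 'text/csv', 'text/csv')
print(content_type)
print(body.splitlines())
```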
