Skip to content
21 changes: 12 additions & 9 deletions site/en/tutorials/keras/regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,10 @@
"test_features = test_dataset.copy()\n",
"\n",
"train_labels = train_features.pop('MPG')\n",
"test_labels = test_features.pop('MPG')"
"test_labels = test_features.pop('MPG')",
"\n",
"train_features = tf.convert_to_tensor(train_features, dtype=tf.float32)\n",
"test_features = tf.convert_to_tensor(test_features, dtype=tf.float32)"
]
},
{
Expand Down Expand Up @@ -545,7 +548,7 @@
},
"outputs": [],
"source": [
"horsepower = np.array(train_features['Horsepower'])\n",
"horsepower = np.array(train_features[:, 2])\n",
"\n",
"horsepower_normalizer = layers.Normalization(input_shape=[1,], axis=None)\n",
"horsepower_normalizer.adapt(horsepower)"
Expand Down Expand Up @@ -639,7 +642,7 @@
"source": [
"%%time\n",
"history = horsepower_model.fit(\n",
" train_features['Horsepower'],\n",
" train_features[:, 2],\n",
" train_labels,\n",
" epochs=100,\n",
" # Suppress logging.\n",
Expand Down Expand Up @@ -719,7 +722,7 @@
"test_results = {}\n",
"\n",
"test_results['horsepower_model'] = horsepower_model.evaluate(\n",
" test_features['Horsepower'],\n",
" test_features[:, 2],\n",
" test_labels, verbose=0)"
]
},
Expand Down Expand Up @@ -753,7 +756,7 @@
"outputs": [],
"source": [
"def plot_horsepower(x, y):\n",
" plt.scatter(train_features['Horsepower'], train_labels, label='Data')\n",
" plt.scatter(train_features[:, 2], train_labels, label='Data')\n",
" plt.plot(x, y, color='k', label='Predictions')\n",
" plt.xlabel('Horsepower')\n",
" plt.ylabel('MPG')\n",
Expand Down Expand Up @@ -1053,7 +1056,7 @@
"source": [
"%%time\n",
"history = dnn_horsepower_model.fit(\n",
" train_features['Horsepower'],\n",
" train_features[:, 2],\n",
" train_labels,\n",
" validation_split=0.2,\n",
" verbose=0, epochs=100)"
Expand Down Expand Up @@ -1129,7 +1132,7 @@
"outputs": [],
"source": [
"test_results['dnn_horsepower_model'] = dnn_horsepower_model.evaluate(\n",
" test_features['Horsepower'], test_labels,\n",
" test_features[:, 2], test_labels,\n",
" verbose=0)"
]
},
Expand Down Expand Up @@ -1321,7 +1324,7 @@
},
"outputs": [],
"source": [
"dnn_model.save('dnn_model.keras')"
"dnn_model.save('dnn_model.tf', save_format='tf')"
]
},
{
Expand All @@ -1341,7 +1344,7 @@
},
"outputs": [],
"source": [
"reloaded = tf.keras.models.load_model('dnn_model.keras')\n",
"reloaded = tf.keras.models.load_model('dnn_model.tf')\n",
"\n",
"test_results['reloaded'] = reloaded.evaluate(\n",
" test_features, test_labels, verbose=0)"
Expand Down