Skip to content

Commit b6974bd

Browse files
Merge pull request #3 from HowardRiddiough/update-notebook
Improved 'spark_predict_pudf'
2 parents be0ce52 + 30ddbd3 commit b6974bd

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

deploying-python-ml-in-pyspark.ipynb

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
"import pandas as pd\n",
2929
"from sklearn.ensemble import RandomForestRegressor\n",
3030
"from sklearn.pipeline import Pipeline\n",
31-
"from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder\n",
31+
"from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder, LabelEncoder\n",
3232
"from sklearn.compose import ColumnTransformer\n",
3333
"import pyspark.sql\n",
3434
"from pyspark.sql import SparkSession\n",
@@ -136,21 +136,23 @@
136136
"metadata": {},
137137
"outputs": [],
138138
"source": [
139-
"def spark_predict(model, *cols) -> pyspark.sql.column:\n",
139+
"def spark_predict(model, cols) -> pyspark.sql.column:\n",
140140
" \"\"\"This function deploys python ml in PySpark using the `predict` method of the `model` parameter.\n",
141141
" \n",
142142
" Args:\n",
143143
" model: python ml model with sklearn API\n",
144-
" *cols (list-like): Features used for predictions, required to be present as columns in the spark \n",
144+
" cols (list-like): Features used for predictions, required to be present as columns in the spark \n",
145145
" DataFrame used to make predictions.\n",
146146
" \"\"\"\n",
147147
" @sf.pandas_udf(returnType=DoubleType())\n",
148148
" def predict_pandas_udf(*cols):\n",
149-
" # cols will be a tuple of pandas.Series here.\n",
150149
" X = pd.concat(cols, axis=1)\n",
151150
" return pd.Series(model.predict(X))\n",
152151
" \n",
153-
" return predict_pandas_udf(*cols)"
152+
" return predict_pandas_udf(*cols)\n",
153+
"\n",
154+
" \n",
155+
" "
154156
]
155157
},
156158
{
@@ -184,7 +186,7 @@
184186
"(\n",
185187
" ddf\n",
186188
" .select(NUMERICAL_FEATURES + [TARGET])\n",
187-
" .withColumn(\"prediction\", spark_predict(rf, *NUMERICAL_FEATURES).alias(\"prediction\"))\n",
189+
" .withColumn(\"prediction\", spark_predict(rf, NUMERICAL_FEATURES).alias(\"prediction\"))\n",
188190
" .show(5)\n",
189191
")"
190192
]
@@ -230,7 +232,7 @@
230232
"(\n",
231233
" ddf\n",
232234
" .select(NUMERICAL_FEATURES + [TARGET])\n",
233-
" .withColumn(\"pipe_predict\", spark_predict(pipe, *NUMERICAL_FEATURES).alias(\"prediction\")).show(5)\n",
235+
" .withColumn(\"pipe_predict\", spark_predict(pipe, NUMERICAL_FEATURES).alias(\"prediction\")).show(5)\n",
234236
")"
235237
]
236238
},
@@ -285,7 +287,7 @@
285287
"(\n",
286288
" ddf\n",
287289
" .select(ALL_FEATURES + [TARGET])\n",
288-
" .withColumn(\"pipe_predict\", spark_predict(preprocessor_pipe, *ALL_FEATURES).alias(\"prediction\"))\n",
290+
" .withColumn(\"pipe_predict\", spark_predict(preprocessor_pipe, ALL_FEATURES).alias(\"prediction\"))\n",
289291
" .show(5)\n",
290292
")"
291293
]

0 commit comments

Comments (0)