|
28 | 28 | "import pandas as pd\n", |
29 | 29 | "from sklearn.ensemble import RandomForestRegressor\n", |
30 | 30 | "from sklearn.pipeline import Pipeline\n", |
31 | | - "from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder\n", |
| 31 | + "from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder, LabelEncoder\n", |
32 | 32 | "from sklearn.compose import ColumnTransformer\n", |
33 | 33 | "import pyspark.sql\n", |
34 | 34 | "from pyspark.sql import SparkSession\n", |
|
136 | 136 | "metadata": {}, |
137 | 137 | "outputs": [], |
138 | 138 | "source": [ |
139 | | - "def spark_predict(model, *cols) -> pyspark.sql.column:\n", |
| 139 | + "def spark_predict(model, cols) -> pyspark.sql.Column:\n", |
140 | 140 | " \"\"\"This function deploys python ml in PySpark using the `predict` method of the `model` parameter.\n", |
141 | 141 | " \n", |
142 | 142 | " Args:\n", |
143 | 143 | " model: python ml model with sklearn API\n", |
144 | | - " *cols (list-like): Features used for predictions, required to be present as columns in the spark \n", |
| 144 | + " cols (list-like): Features used for predictions, required to be present as columns in the spark \n", |
145 | 145 | " DataFrame used to make predictions.\n", |
146 | 146 | " \"\"\"\n", |
147 | 147 | " @sf.pandas_udf(returnType=DoubleType())\n", |
148 | 148 | " def predict_pandas_udf(*cols):\n", |
149 | | - " # cols will be a tuple of pandas.Series here.\n", |
150 | 149 | " X = pd.concat(cols, axis=1)\n", |
151 | 150 | " return pd.Series(model.predict(X))\n", |
152 | 151 | " \n", |
153 | | - " return predict_pandas_udf(*cols)" |
| 152 | + " return predict_pandas_udf(*cols)\n", |
| 153 | + "\n", |
| 154 | + " \n", |
| 155 | + " " |
154 | 156 | ] |
155 | 157 | }, |
156 | 158 | { |
|
184 | 186 | "(\n", |
185 | 187 | " ddf\n", |
186 | 188 | " .select(NUMERICAL_FEATURES + [TARGET])\n", |
187 | | - " .withColumn(\"prediction\", spark_predict(rf, *NUMERICAL_FEATURES).alias(\"prediction\"))\n", |
| 189 | + " .withColumn(\"prediction\", spark_predict(rf, NUMERICAL_FEATURES).alias(\"prediction\"))\n", |
188 | 190 | " .show(5)\n", |
189 | 191 | ")" |
190 | 192 | ] |
|
230 | 232 | "(\n", |
231 | 233 | " ddf\n", |
232 | 234 | " .select(NUMERICAL_FEATURES + [TARGET])\n", |
233 | | - " .withColumn(\"pipe_predict\", spark_predict(pipe, *NUMERICAL_FEATURES).alias(\"prediction\")).show(5)\n", |
| 235 | + " .withColumn(\"pipe_predict\", spark_predict(pipe, NUMERICAL_FEATURES).alias(\"prediction\")).show(5)\n", |
234 | 236 | ")" |
235 | 237 | ] |
236 | 238 | }, |
|
285 | 287 | "(\n", |
286 | 288 | " ddf\n", |
287 | 289 | " .select(ALL_FEATURES + [TARGET])\n", |
288 | | - " .withColumn(\"pipe_predict\", spark_predict(preprocessor_pipe, *ALL_FEATURES).alias(\"prediction\"))\n", |
| 290 | + " .withColumn(\"pipe_predict\", spark_predict(preprocessor_pipe, ALL_FEATURES).alias(\"prediction\"))\n", |
289 | 291 | " .show(5)\n", |
290 | 292 | ")" |
291 | 293 | ] |
|
0 commit comments