Skip to content

Commit 868a58b

Browse files
raghavrvamueller
authored andcommitted
[MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able (scikit-learn#7594)
* FIX Subclass a new MaskedArray which allows pickling even when dype=object * TST unpickling too * FIX Use MaskedArray from utils.fixes rather than from numpy * FIX imports * Don't assign a variable * FIX np --> numpy * Use tostring instead of tobytes for old numpy * COSMIT pickle-able --> picklable * use #noqa comment to turn off flake8 * TST/ENH Check if the pickled est's predict matches with the original one's
1 parent 7955ce0 commit 868a58b

File tree

4 files changed

+48
-7
lines changed

4 files changed

+48
-7
lines changed

sklearn/model_selection/_search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..utils import check_random_state
3131
from ..utils.fixes import sp_version
3232
from ..utils.fixes import rankdata
33+
from ..utils.fixes import MaskedArray
3334
from ..utils.random import sample_without_replacement
3435
from ..utils.validation import indexable, check_is_fitted
3536
from ..utils.metaestimators import if_delegate_has_method
@@ -611,10 +612,12 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
611612
best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
612613
best_parameters = candidate_params[best_index]
613614

614-
# Use one np.MaskedArray and mask all the places where the param is not
615+
# Use one MaskedArray and mask all the places where the param is not
615616
# applicable for that candidate. Use defaultdict as each candidate may
616617
# not contain all the params
617-
param_results = defaultdict(partial(np.ma.masked_all, (n_candidates,),
618+
param_results = defaultdict(partial(MaskedArray,
619+
np.empty(n_candidates,),
620+
mask=True,
618621
dtype=object))
619622
for cand_i, params in enumerate(candidate_params):
620623
for name, value in params.items():

sklearn/model_selection/tests/test_search.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,12 +940,16 @@ def test_pickle():
940940
clf = MockClassifier()
941941
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
942942
grid_search.fit(X, y)
943-
pickle.dumps(grid_search) # smoke test
943+
grid_search_pickled = pickle.loads(pickle.dumps(grid_search))
944+
assert_array_almost_equal(grid_search.predict(X),
945+
grid_search_pickled.predict(X))
944946

945947
random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
946948
refit=True, n_iter=3)
947949
random_search.fit(X, y)
948-
pickle.dumps(random_search) # smoke test
950+
random_search_pickled = pickle.loads(pickle.dumps(random_search))
951+
assert_array_almost_equal(random_search.predict(X),
952+
random_search_pickled.predict(X))
949953

950954

951955
def test_grid_search_with_multioutput_data():

sklearn/utils/fixes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,21 @@ def rankdata(a, method='average'):
401401
return .5 * (count[dense] + count[dense - 1] + 1)
402402
else:
403403
from scipy.stats import rankdata
404+
405+
406+
if np_version < (1, 12, 0):
407+
class MaskedArray(np.ma.MaskedArray):
408+
# Before numpy 1.12, np.ma.MaskedArray object is not picklable
409+
# This fix is needed to make our model_selection.GridSearchCV
410+
# picklable as the ``cv_results_`` param uses MaskedArray
411+
def __getstate__(self):
412+
"""Return the internal state of the masked array, for pickling
413+
purposes.
414+
415+
"""
416+
cf = 'CF'[self.flags.fnc]
417+
data_state = super(np.ma.MaskedArray, self).__reduce__()[2]
418+
return data_state + (np.ma.getmaskarray(self).tostring(cf),
419+
self._fill_value)
420+
else:
421+
from numpy.ma import MaskedArray # noqa

sklearn/utils/tests/test_fixes.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
# Lars Buitinck
44
# License: BSD 3 clause
55

6+
import pickle
67
import numpy as np
78

8-
from numpy.testing import (assert_almost_equal,
9-
assert_array_almost_equal)
9+
from sklearn.utils.testing import assert_equal
10+
from sklearn.utils.testing import assert_false
11+
from sklearn.utils.testing import assert_true
12+
from sklearn.utils.testing import assert_almost_equal
13+
from sklearn.utils.testing import assert_array_equal
14+
from sklearn.utils.testing import assert_array_almost_equal
15+
1016
from sklearn.utils.fixes import divide, expit
1117
from sklearn.utils.fixes import astype
12-
from sklearn.utils.testing import assert_equal, assert_false, assert_true
18+
from sklearn.utils.fixes import MaskedArray
1319

1420

1521
def test_expit():
@@ -50,3 +56,13 @@ def test_astype_copy_memory():
5056

5157
e_int32 = astype(a_int32, dtype=np.int32)
5258
assert_false(np.may_share_memory(e_int32, a_int32))
59+
60+
61+
def test_masked_array_obj_dtype_pickleable():
62+
marr = MaskedArray([1, None, 'a'], dtype=object)
63+
64+
for mask in (True, False, [0, 1, 0]):
65+
marr.mask = mask
66+
marr_pickled = pickle.loads(pickle.dumps(marr))
67+
assert_array_equal(marr.data, marr_pickled.data)
68+
assert_array_equal(marr.mask, marr_pickled.mask)

0 commit comments

Comments
 (0)