Skip to content

Commit 911792b

Browse files
NicolasHug authored and jnothman committed
ENH Avoid calling _encode_check_unknown() twice in BaseEncoder.transform (scikit-learn#13810)
1 parent 0bfa52d commit 911792b

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ def _transform(self, X, handle_unknown='error'):
134134
Xi = Xi.copy()
135135

136136
Xi[~valid_mask] = self.categories_[i][0]
137-
_, encoded = _encode(Xi, self.categories_[i], encode=True)
137+
# We use check_unknown=False, since _encode_check_unknown was
138+
# already called above.
139+
_, encoded = _encode(Xi, self.categories_[i], encode=True,
140+
check_unknown=False)
138141
X_int[:, i] = encoded
139142

140143
return X_int, X_mask

sklearn/preprocessing/label.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
]
3434

3535

36-
def _encode_numpy(values, uniques=None, encode=False):
36+
def _encode_numpy(values, uniques=None, encode=False, check_unknown=True):
3737
# only used in _encode below, see docstring there for details
3838
if uniques is None:
3939
if encode:
@@ -43,10 +43,11 @@ def _encode_numpy(values, uniques=None, encode=False):
4343
# unique sorts
4444
return np.unique(values)
4545
if encode:
46-
diff = _encode_check_unknown(values, uniques)
47-
if diff:
48-
raise ValueError("y contains previously unseen labels: %s"
49-
% str(diff))
46+
if check_unknown:
47+
diff = _encode_check_unknown(values, uniques)
48+
if diff:
49+
raise ValueError("y contains previously unseen labels: %s"
50+
% str(diff))
5051
encoded = np.searchsorted(uniques, values)
5152
return uniques, encoded
5253
else:
@@ -70,7 +71,7 @@ def _encode_python(values, uniques=None, encode=False):
7071
return uniques
7172

7273

73-
def _encode(values, uniques=None, encode=False):
74+
def _encode(values, uniques=None, encode=False, check_unknown=True):
7475
"""Helper function to factorize (find uniques) and encode values.
7576
7677
Uses pure python method for object dtype, and numpy method for
@@ -90,6 +91,12 @@ def _encode(values, uniques=None, encode=False):
9091
already have been determined in fit).
9192
encode : bool, default False
9293
If True, also encode the values into integer codes based on `uniques`.
94+
check_unknown : bool, default True
95+
If True, check for values in ``values`` that are not in ``uniques``
96+
and raise an error. This is ignored for object dtype, and treated as
97+
True in this case. This parameter is useful for
98+
_BaseEncoder._transform() to avoid calling _encode_check_unknown()
99+
twice.
93100
94101
Returns
95102
-------
@@ -107,7 +114,8 @@ def _encode(values, uniques=None, encode=False):
107114
raise TypeError("argument must be a string or number")
108115
return res
109116
else:
110-
return _encode_numpy(values, uniques, encode)
117+
return _encode_numpy(values, uniques, encode,
118+
check_unknown=check_unknown)
111119

112120

113121
def _encode_check_unknown(values, uniques, return_mask=False):

sklearn/preprocessing/tests/test_label.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,24 @@ def test_encode_util(values, expected):
605605
assert_array_equal(encoded, np.array([1, 0, 2, 0, 2]))
606606
_, encoded = _encode(values, uniques, encode=True)
607607
assert_array_equal(encoded, np.array([1, 0, 2, 0, 2]))
608+
609+
610+
def test_encode_check_unknown():
611+
# test for the check_unknown parameter of _encode()
612+
uniques = np.array([1, 2, 3])
613+
values = np.array([1, 2, 3, 4])
614+
615+
# Default is True, raise error
616+
with pytest.raises(ValueError,
617+
match='y contains previously unseen labels'):
618+
_encode(values, uniques, encode=True, check_unknown=True)
619+
620+
# don't raise an error if check_unknown is False
621+
_encode(values, uniques, encode=True, check_unknown=False)
622+
623+
# parameter is ignored for object dtype
624+
uniques = np.array(['a', 'b', 'c'], dtype=object)
625+
values = np.array(['a', 'b', 'c', 'd'], dtype=object)
626+
with pytest.raises(ValueError,
627+
match='y contains previously unseen labels'):
628+
_encode(values, uniques, encode=True, check_unknown=False)

0 commit comments

Comments
 (0)