Skip to content

Commit 8da5092

Browse files
MechCoder authored and jnothman committed
[MRG] Fail imputer early when number of features are not the same in fit and transform (scikit-learn#7374)
1 parent 12a0125 commit 8da5092

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

sklearn/preprocessing/imputation.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -309,11 +309,17 @@ def transform(self, X):
309309
"""
310310
if self.axis == 0:
311311
check_is_fitted(self, 'statistics_')
312+
X = check_array(X, accept_sparse='csc', dtype=FLOAT_DTYPES,
313+
force_all_finite=False, copy=self.copy)
314+
statistics = self.statistics_
315+
if X.shape[1] != statistics.shape[0]:
316+
raise ValueError("X has %d features per sample, expected %d"
317+
% (X.shape[1], self.statistics_.shape[0]))
312318

313319
# Since two different arrays can be provided in fit(X) and
314320
# transform(X), the imputation data need to be recomputed
315321
# when the imputation is done per sample
316-
if self.axis == 1:
322+
else:
317323
X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES,
318324
force_all_finite=False, copy=self.copy)
319325

@@ -328,10 +334,6 @@ def transform(self, X):
328334
self.strategy,
329335
self.missing_values,
330336
self.axis)
331-
else:
332-
X = check_array(X, accept_sparse='csc', dtype=FLOAT_DTYPES,
333-
force_all_finite=False, copy=self.copy)
334-
statistics = self.statistics_
335337

336338
# Delete the invalid rows/columns
337339
invalid_mask = np.isnan(statistics)

0 commit comments

Comments
 (0)