@@ -309,11 +309,17 @@ def transform(self, X):
309309 """
310310 if self .axis == 0 :
311311 check_is_fitted (self , 'statistics_' )
312+ X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
313+ force_all_finite = False , copy = self .copy )
314+ statistics = self .statistics_
315+ if X .shape [1 ] != statistics .shape [0 ]:
316+ raise ValueError ("X has %d features per sample, expected %d"
317+ % (X .shape [1 ], self .statistics_ .shape [0 ]))
312318
313319 # Since two different arrays can be provided in fit(X) and
314320 # transform(X), the imputation data need to be recomputed
315321 # when the imputation is done per sample
316- if self . axis == 1 :
322+ else :
317323 X = check_array (X , accept_sparse = 'csr' , dtype = FLOAT_DTYPES ,
318324 force_all_finite = False , copy = self .copy )
319325
@@ -328,10 +334,6 @@ def transform(self, X):
328334 self .strategy ,
329335 self .missing_values ,
330336 self .axis )
331- else :
332- X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
333- force_all_finite = False , copy = self .copy )
334- statistics = self .statistics_
335337
336338 # Delete the invalid rows/columns
337339 invalid_mask = np .isnan (statistics )
0 commit comments