@@ -174,6 +174,26 @@ def _onedal_fit(self, X, y, queue=None):
174174 # print("=" * 50, file=sys.stderr, flush=True)
175175 # print("DEBUG: _onedal_fit called!", file=sys.stderr, flush=True)
176176 # print("=" * 50, file=sys.stderr, flush=True)
177+
178+ # Perform preprocessing at sklearnex level
179+ X , y = self ._validate_data (
180+ X , y , dtype = [np .float64 , np .float32 ], accept_sparse = "csr"
181+ )
182+
183+ # Validate n_neighbors
184+ self ._validate_n_neighbors (self .n_neighbors )
185+
186+ # Parse auto method
187+ self ._fit_method = self ._parse_auto_method (self .algorithm , X .shape [0 ], X .shape [1 ])
188+
189+ # Validate classification targets
190+ from onedal .utils .validation import _check_classification_targets
191+
192+ _check_classification_targets (y )
193+
194+ # Handle shape and class processing at sklearnex level
195+ y = self ._process_classification_targets (y )
196+
177197 onedal_params = {
178198 "n_neighbors" : self .n_neighbors ,
179199 "weights" : self .weights ,
@@ -230,4 +250,4 @@ def _save_attributes(self):
230250 predict .__doc__ = _sklearn_KNeighborsClassifier .predict .__doc__
231251 predict_proba .__doc__ = _sklearn_KNeighborsClassifier .predict_proba .__doc__
232252 score .__doc__ = _sklearn_KNeighborsClassifier .score .__doc__
233- kneighbors .__doc__ = _sklearn_KNeighborsClassifier .kneighbors .__doc__
253+ kneighbors .__doc__ = _sklearn_KNeighborsClassifier .kneighbors .__doc__
0 commit comments