Skip to content

Commit debfcdf

Browse files
committed
rebase: rebase to main
1 parent 8bd86c2 commit debfcdf

File tree

5 files changed

+25
-14
lines changed

5 files changed

+25
-14
lines changed

onedal/neighbors/neighbors.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,6 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
170170

171171
if X is not None:
172172
query_is_train = False
173-
<<<<<<< HEAD
174-
<<<<<<< HEAD
175-
if not use_raw_input:
176-
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
177-
=======
178-
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
179-
>>>>>>> e003b37f (fix: try it again)
180-
=======
181-
>>>>>>> 8cd6f2b2 (fix: first round of refactor move preprocssing function to sklearnex)
182173
else:
183174
query_is_train = True
184175
X = self._fit_X
@@ -517,4 +508,4 @@ def fit(self, X, y, queue=None):
517508

518509
@supports_queue
519510
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
520-
return self._kneighbors(X, n_neighbors, return_distance)
511+
return self._kneighbors(X, n_neighbors, return_distance)

sklearnex/neighbors/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,4 +476,4 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
476476

477477
return kneighbors_graph
478478

479-
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__
479+
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__

sklearnex/neighbors/knn_classification.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,26 @@ def _onedal_fit(self, X, y, queue=None):
174174
# print("=" * 50, file=sys.stderr, flush=True)
175175
# print("DEBUG: _onedal_fit called!", file=sys.stderr, flush=True)
176176
# print("=" * 50, file=sys.stderr, flush=True)
177+
178+
# Perform preprocessing at sklearnex level
179+
X, y = self._validate_data(
180+
X, y, dtype=[np.float64, np.float32], accept_sparse="csr"
181+
)
182+
183+
# Validate n_neighbors
184+
self._validate_n_neighbors(self.n_neighbors)
185+
186+
# Parse auto method
187+
self._fit_method = self._parse_auto_method(self.algorithm, X.shape[0], X.shape[1])
188+
189+
# Validate classification targets
190+
from onedal.utils.validation import _check_classification_targets
191+
192+
_check_classification_targets(y)
193+
194+
# Handle shape and class processing at sklearnex level
195+
y = self._process_classification_targets(y)
196+
177197
onedal_params = {
178198
"n_neighbors": self.n_neighbors,
179199
"weights": self.weights,
@@ -230,4 +250,4 @@ def _save_attributes(self):
230250
predict.__doc__ = _sklearn_KNeighborsClassifier.predict.__doc__
231251
predict_proba.__doc__ = _sklearn_KNeighborsClassifier.predict_proba.__doc__
232252
score.__doc__ = _sklearn_KNeighborsClassifier.score.__doc__
233-
kneighbors.__doc__ = _sklearn_KNeighborsClassifier.kneighbors.__doc__
253+
kneighbors.__doc__ = _sklearn_KNeighborsClassifier.kneighbors.__doc__

sklearnex/neighbors/knn_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,4 @@ def _save_attributes(self):
209209
fit.__doc__ = _sklearn_KNeighborsRegressor.__doc__
210210
predict.__doc__ = _sklearn_KNeighborsRegressor.predict.__doc__
211211
kneighbors.__doc__ = _sklearn_KNeighborsRegressor.kneighbors.__doc__
212-
score.__doc__ = _sklearn_KNeighborsRegressor.score.__doc__
212+
score.__doc__ = _sklearn_KNeighborsRegressor.score.__doc__

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,4 @@ def _save_attributes(self):
195195
radius_neighbors.__doc__ = _sklearn_NearestNeighbors.radius_neighbors.__doc__
196196
radius_neighbors_graph.__doc__ = (
197197
_sklearn_NearestNeighbors.radius_neighbors_graph.__doc__
198-
)
198+
)

0 commit comments

Comments
 (0)