Skip to content

Commit fa09f8b

Browse files
author
cwork
committed
TDL: added kwargs_grid_search parameter, which is passed to the GridSearchCV class. Gives the user more freedom in setting up the GridSearch. For cv, which was manually set before the standart value is set to 10 s.t. previous behavior is preserved. (resolves #546)
1 parent 3497c8d commit fa09f8b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

ontolearn/learners/tree_learner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def __init__(self, knowledge_base,
157157
max_runtime: int = 1,
158158
grid_search_over: dict = None,
159159
grid_search_apply: bool = False,
160+
kwargs_grid_search: dict = None,
160161
report_classification: bool = True,
161162
plot_tree: bool = False,
162163
plot_embeddings: bool = False,
@@ -182,13 +183,17 @@ def __init__(self, knowledge_base,
182183
pass
183184
else:
184185
grid_search_over = dict()
186+
187+
kwargs_grid_search.setdefault("cv", 10)
188+
185189
assert (
186190
isinstance(knowledge_base, KnowledgeBase)
187191
or isinstance(knowledge_base, ontolearn.triple_store.TripleStore)
188192
or isinstance(knowledge_base)
189193
), "knowledge_base must be a KnowledgeBase instance"
190194
print(f"Knowledge Base: {knowledge_base}")
191195
self.grid_search_over = grid_search_over
196+
self.kwargs_grid_search = kwargs_grid_search
192197
self.knowledge_base = knowledge_base
193198
self.report_classification = report_classification
194199
self.plot_tree = plot_tree
@@ -387,7 +392,7 @@ def fit(self, learning_problem: PosNegLPStandard = None, max_runtime: int = None
387392
if self.grid_search_over:
388393
grid_search = sklearn.model_selection.GridSearchCV(
389394
tree.DecisionTreeClassifier(**self.kwargs_classifier),
390-
param_grid=self.grid_search_over, cv=10, ).fit(X.values, y.values)
395+
param_grid=self.grid_search_over, **self.kwargs_grid_search).fit(X.values, y.values)
391396
print(grid_search.best_params_)
392397
self.kwargs_classifier.update(grid_search.best_params_)
393398
# Training

0 commit comments

Comments
 (0)