Skip to content

Commit f903dcb

Browse files
authored
Merge pull request #549 from CronosC/tdl_minor_fixes
Addressed some minor issues in the TDL implementation
2 parents c3699dd + 3b362de commit f903dcb

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

ontolearn/learners/tree_learner.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
OWLObjectIntersectionOf,
3131
OWLClassExpression,
3232
OWLObjectUnionOf,
33+
OWLObjectOneOf,
34+
OWLObjectHasValue
3335
)
36+
from owlapy.utils import HasFiller, HasOperands
3437
from owlapy.owl_individual import OWLNamedIndividual
3538
import ontolearn.triple_store
3639
from ontolearn.knowledge_base import KnowledgeBase
@@ -132,19 +135,32 @@ def concepts_reducer(
132135
dl_concept_path = reduced_cls((dl_concept_path, c))
133136
return dl_concept_path
134137

138+
def contains_nominal(expr: OWLClassExpression) -> bool:
    """Report whether ``expr`` is, or structurally contains, a nominal.

    An expression counts as a nominal when it is an ``OWLObjectOneOf`` or an
    ``OWLObjectHasValue``. Otherwise the search descends into the filler of a
    ``HasFiller`` expression, or into every operand of a ``HasOperands``
    expression.

    Args:
        expr: The expression to inspect.

    Returns:
        True if a nominal is found at ``expr`` or anywhere below it,
        otherwise False.
    """
    if isinstance(expr, (OWLObjectOneOf, OWLObjectHasValue)):
        return True

    # The filler branch takes precedence: an expression that is a HasFiller
    # is never examined for operands.
    if isinstance(expr, HasFiller):
        return contains_nominal(expr.get_filler())

    if not isinstance(expr, HasOperands):
        return False

    # NOTE(review): confirm that the HasOperands interface really exposes
    # `get_operands()` in the owlapy version in use -- TODO verify.
    for operand in expr.get_operands():
        if contains_nominal(operand):
            return True
    return False
135150

136151
class TDL:
137152
"""Tree-based Description Logic Concept Learner"""
138153

139154
def __init__(self, knowledge_base,
140155
use_inverse: bool = False,
141156
use_data_properties: bool = False,
142-
use_nominals: bool = False,
157+
use_nominals: bool = True,
143158
use_card_restrictions: bool = False,
144159
kwargs_classifier: dict = None,
145160
max_runtime: int = 1,
146161
grid_search_over: dict = None,
147162
grid_search_apply: bool = False,
163+
kwargs_grid_search: dict = {},
148164
report_classification: bool = True,
149165
plot_tree: bool = False,
150166
plot_embeddings: bool = False,
@@ -166,15 +182,21 @@ def __init__(self, knowledge_base,
166182
"min_samples_leaf": [1, 2, 3, 4, 5, 10],
167183
"max_depth": [1, 2, 3, 4, 5, 10, None],
168184
}
185+
elif grid_search_apply and grid_search_over is not None:
186+
pass
169187
else:
170188
grid_search_over = dict()
189+
190+
kwargs_grid_search.setdefault("cv", 10)
191+
171192
assert (
172193
isinstance(knowledge_base, KnowledgeBase)
173194
or isinstance(knowledge_base, ontolearn.triple_store.TripleStore)
174195
or isinstance(knowledge_base)
175196
), "knowledge_base must be a KnowledgeBase instance"
176197
print(f"Knowledge Base: {knowledge_base}")
177198
self.grid_search_over = grid_search_over
199+
self.kwargs_grid_search = kwargs_grid_search
178200
self.knowledge_base = knowledge_base
179201
self.report_classification = report_classification
180202
self.plot_tree = plot_tree
@@ -209,11 +231,12 @@ def extract_expressions_from_owl_individuals(self, individuals: List[OWLNamedInd
209231
verbose=self.verbose,
210232
desc="Extracting information about examples"):
211233
for owl_class_expression in self.knowledge_base.abox(individual=owl_named_individual, mode="expression"):
212-
str_dl_concept=owl_expression_to_dl(owl_class_expression)
213-
individuals_to_feature_mapping.setdefault(owl_named_individual.str,set()).add(str_dl_concept)
214-
if str_dl_concept not in features:
215-
# A mapping from str dl representation to owl object.
216-
features[str_dl_concept] = owl_class_expression
234+
if self.use_nominals or not contains_nominal(owl_class_expression):
235+
str_dl_concept=owl_expression_to_dl(owl_class_expression)
236+
individuals_to_feature_mapping.setdefault(owl_named_individual.str,set()).add(str_dl_concept)
237+
if str_dl_concept not in features:
238+
# A mapping from str dl representation to owl object.
239+
features[str_dl_concept] = owl_class_expression
217240

218241
assert len(features) > 0, "First hop features cannot be extracted. Ensure that there are axioms about the examples."
219242
if self.verbose > 0:
@@ -372,7 +395,7 @@ def fit(self, learning_problem: PosNegLPStandard = None, max_runtime: int = None
372395
if self.grid_search_over:
373396
grid_search = sklearn.model_selection.GridSearchCV(
374397
tree.DecisionTreeClassifier(**self.kwargs_classifier),
375-
param_grid=self.grid_search_over, cv=10, ).fit(X.values, y.values)
398+
param_grid=self.grid_search_over, **self.kwargs_grid_search).fit(X.values, y.values)
376399
print(grid_search.best_params_)
377400
self.kwargs_classifier.update(grid_search.best_params_)
378401
# Training

0 commit comments

Comments
 (0)