77from scipy .stats import mode
88
99class DecisionTree :
def __init__(self, feature_names=None, min_samples=11, method='c4.5', max_depth=-1):
    """Store the tree's settings and bind the algorithm-specific handlers.

    Args:
        feature_names: Kept unchanged on ``self.feature_names``.
        min_samples: Kept on ``self.min_sample`` (note the singular
            attribute name).
        method: ``'c4.5'`` selects the ``*_c4`` handlers; any other value
            selects the ``*_cart`` handlers.
        max_depth: Kept on ``self.max_depth``. The default is ``-1``;
            presumably "no limit" -- confirm against the fitting code.
    """
    # No tree exists until one is built elsewhere in the class.
    self.tree = None
    self.feature_names = feature_names

    self.min_sample = min_samples
    self.method = method
    self.max_depth = max_depth

    # Decide once which family of recursive handlers this instance uses.
    use_c45 = method == 'c4.5'
    self.__rec_fit = self.rec_fit_c4 if use_c45 else self.rec_fit_cart
    self.__rec_predict = self.rec_predict_c4 if use_c45 else self.rec_predict_cart
    self.__rec_rules = self.rec_rules_c4 if use_c45 else self.rec_rules_cart
1719
1820 def _entropy (self , labels ):
1921 _ , c = np .unique (labels , return_counts = True )
@@ -149,4 +151,69 @@ def get_rules(self):
149151 self .__rec_rules (self .tree , '' )
150152 new_rules = [' ' .join (x .split (' ' )[2 :]) for x in self .rules ]
151153 self .rules = new_rules
152- return new_rules
154+ return new_rules
155+
def build_tree_structure(self, rules_dict):
    """Fold a ``{path: rules}`` mapping into a nested dictionary.

    Each path is stripped of surrounding whitespace and split into
    segments; every segment becomes one nesting level. The value that
    belongs to the path is stored under the ``'rules'`` key of the
    innermost level.

    Args:
        rules_dict: Mapping from a path string to the rules for that path.

    Returns:
        The nested dictionary; empty when ``rules_dict`` is empty.
    """
    root = {}
    for path, rules in rules_dict.items():
        cursor = root
        for segment in path.strip().split('\n '):
            # Reuse the level if an earlier path already created it.
            cursor = cursor.setdefault(segment, {})
        # Attach the rules at the deepest level reached by this path.
        cursor['rules'] = rules
    return root
168+
def render_tree(self, tree, depth=0):
    """Flatten a nested dictionary into a list of prefixed text lines.

    Keys are visited in sorted order. The ``'rules'`` entry is treated as
    a collection of strings and emitted, sorted, at the current depth.
    Every other entry is emitted as a line of its own, followed by the
    lines of its sub-dictionary one level deeper.

    The depth prefix is always prepended, whatever prefix the key or rule
    text may already contain.

    Args:
        tree: Nested dictionary to render.
        depth: Nesting level of ``tree``; sets how often the prefix repeats.

    Returns:
        The rendered lines, in output order.
    """
    indent = '| ' * depth
    rendered = []
    for key in sorted(tree):
        child = tree[key]
        if key != 'rules':
            rendered.append(f"{indent} {key} ")
            rendered += self.render_tree(child, depth + 1)
            continue
        # Leaf payload: one line per rule, in sorted order.
        rendered.extend(f"{indent} {rule} " for rule in sorted(child))
    return rendered
179+
def extract_rules(self):
    """Return the rules of ``self.tree`` as one multi-line string.

    Three steps: ``rec_rules`` gathers the rules starting at the root,
    ``build_tree_structure`` nests them by path, and ``render_tree`` turns
    the nested structure into lines, which are joined with the line
    separator.
    """
    rules_by_path = self.rec_rules(self.tree)
    nested = self.build_tree_structure(rules_by_path)
    rendered_lines = self.render_tree(nested)
    return '\n '.join(rendered_lines)
187+
def rec_rules(self, node, depth=0, path='', current_rules=None):
    """Collect the leaf rules below ``node``, grouped by decision path.

    A tuple is treated as a leaf whose first item is formatted as an
    integer class and whose second item as a two-decimal number. Any other
    node is expected to be a mapping ``{feature: {(cond_type, split_value):
    subtree}}``; each branch extends the path by one line and is explored
    recursively.

    Every generated string starts with its own depth prefix.

    Args:
        node: Leaf tuple or internal mapping to walk.
        depth: Nesting level of ``node``; sets how often the prefix repeats.
        path: Decision path accumulated so far (``''`` at the root).
        current_rules: Accumulator shared across the recursion; a new dict
            is created when omitted.

    Returns:
        ``current_rules``: a dict mapping each path string to the set of
        leaf rule strings reached through it.
    """
    if current_rules is None:
        current_rules = {}

    indent = '| ' * depth

    if isinstance(node, tuple):  # Leaf node check
        rule = f"{indent} |--- class: {int(node[0])} ({node[1]:.2f} )"
        current_rules.setdefault(path, set()).add(rule)
        return current_rules

    for feature, subtrees in node.items():
        # feature_names defaults to None in the constructor; indexing it
        # would raise TypeError, so fall back to a generated label.
        if self.feature_names is None:
            feature_label = f"feature_{feature}"
        else:
            feature_label = self.feature_names[feature]

        for (cond_type, split_value), subtree in subtrees.items():
            condition_str = ""
            if self.method == "cart":
                # Exact-type match kept so bool and subclass values still
                # produce the empty label, as before.
                if type(cond_type) is int:
                    condition_str = "IS" if cond_type == 1 else "IS NOT"
                elif type(cond_type) is str:
                    condition_str = cond_type
            elif self.method == "c4.5":
                condition_str = str(cond_type) if cond_type else "IS"

            new_path_part = f"{indent} |--- {feature_label} {condition_str} {split_value:.2f} "
            new_path = path + ('\n ' if path else '') + new_path_part

            self.rec_rules(subtree, depth + 1, new_path, current_rules)

    return current_rules
0 commit comments