22BaseNode Module
33"""
44
5- from abc import ABC , abstractmethod
6- from typing import Optional , List
75import re
6+ from abc import ABC , abstractmethod
7+ from typing import List , Optional
8+
9+ from ..utils import get_logger
810
911
1012class BaseNode (ABC ):
@@ -14,10 +16,11 @@ class BaseNode(ABC):
1416 Attributes:
1517 node_name (str): The unique identifier name for the node.
1618 input (str): Boolean expression defining the input keys needed from the state.
17- output (List[str]): List of
19+ output (List[str]): List of
1820 min_input_len (int): Minimum required number of input keys.
1921 node_config (Optional[dict]): Additional configuration for the node.
20-
22+ logger (logging.Logger): The centralized root logger
23+
2124 Args:
2225 node_name (str): Name for identifying the node.
2326 node_type (str): Type of the node; must be 'node' or 'conditional_node'.
@@ -28,7 +31,7 @@ class BaseNode(ABC):
2831
2932 Raises:
3033 ValueError: If `node_type` is not one of the allowed types.
31-
34+
3235 Example:
3336 >>> class MyNode(BaseNode):
3437 ... def execute(self, state):
@@ -40,18 +43,27 @@ class BaseNode(ABC):
4043 {'key': 'value'}
4144 """
4245
43- def __init__ (self , node_name : str , node_type : str , input : str , output : List [str ],
44- min_input_len : int = 1 , node_config : Optional [dict ] = None ):
46+ def __init__ (
47+ self ,
48+ node_name : str ,
49+ node_type : str ,
50+ input : str ,
51+ output : List [str ],
52+ min_input_len : int = 1 ,
53+ node_config : Optional [dict ] = None ,
54+ ):
4555
4656 self .node_name = node_name
4757 self .input = input
4858 self .output = output
4959 self .min_input_len = min_input_len
5060 self .node_config = node_config
61+ self .logger = get_logger ()
5162
5263 if node_type not in ["node" , "conditional_node" ]:
5364 raise ValueError (
54- f"node_type must be 'node' or 'conditional_node', got '{ node_type } '" )
65+ f"node_type must be 'node' or 'conditional_node', got '{ node_type } '"
66+ )
5567 self .node_type = node_type
5668
5769 @abstractmethod
@@ -102,8 +114,7 @@ def get_input_keys(self, state: dict) -> List[str]:
102114 self ._validate_input_keys (input_keys )
103115 return input_keys
104116 except ValueError as e :
105- raise ValueError (
106- f"Error parsing input keys for { self .node_name } : { str (e )} " )
117+ raise ValueError (f"Error parsing input keys for { self .node_name } : { str (e )} " )
107118
108119 def _validate_input_keys (self , input_keys ):
109120 """
@@ -119,7 +130,8 @@ def _validate_input_keys(self, input_keys):
119130 if len (input_keys ) < self .min_input_len :
120131 raise ValueError (
121132 f"""{ self .node_name } requires at least { self .min_input_len } input keys,
122- got { len (input_keys )} .""" )
133+ got { len (input_keys )} ."""
134+ )
123135
124136 def _parse_input_keys (self , state : dict , expression : str ) -> List [str ]:
125137 """
@@ -142,67 +154,80 @@ def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
142154 raise ValueError ("Empty expression." )
143155
144156 # Check for adjacent state keys without an operator between them
145- pattern = r'\b(' + '|' .join (re .escape (key ) for key in state .keys ()) + \
146- r')(\b\s*\b)(' + '|' .join (re .escape (key )
147- for key in state .keys ()) + r')\b'
157+ pattern = (
158+ r"\b("
159+ + "|" .join (re .escape (key ) for key in state .keys ())
160+ + r")(\b\s*\b)("
161+ + "|" .join (re .escape (key ) for key in state .keys ())
162+ + r")\b"
163+ )
148164 if re .search (pattern , expression ):
149165 raise ValueError (
150- "Adjacent state keys found without an operator between them." )
166+ "Adjacent state keys found without an operator between them."
167+ )
151168
152169 # Remove spaces
153170 expression = expression .replace (" " , "" )
154171
155172 # Check for operators with empty adjacent tokens or at the start/end
156- if expression [0 ] in '&|' or expression [- 1 ] in '&|' \
157- or '&&' in expression or '||' in expression or \
158- '&|' in expression or '|&' in expression :
173+ if (
174+ expression [0 ] in "&|"
175+ or expression [- 1 ] in "&|"
176+ or "&&" in expression
177+ or "||" in expression
178+ or "&|" in expression
179+ or "|&" in expression
180+ ):
159181 raise ValueError ("Invalid operator usage." )
160182
161183 # Check for balanced parentheses and valid operator placement
162184 open_parentheses = close_parentheses = 0
163185 for i , char in enumerate (expression ):
164- if char == '(' :
186+ if char == "(" :
165187 open_parentheses += 1
166- elif char == ')' :
188+ elif char == ")" :
167189 close_parentheses += 1
168190 # Check for invalid operator sequences
169191 if char in "&|" and i + 1 < len (expression ) and expression [i + 1 ] in "&|" :
170192 raise ValueError (
171- "Invalid operator placement: operators cannot be adjacent." )
193+ "Invalid operator placement: operators cannot be adjacent."
194+ )
172195
173196 # Check for missing or balanced parentheses
174197 if open_parentheses != close_parentheses :
175- raise ValueError (
176- "Missing or unbalanced parentheses in expression." )
198+ raise ValueError ("Missing or unbalanced parentheses in expression." )
177199
178200 # Helper function to evaluate an expression without parentheses
179201 def evaluate_simple_expression (exp : str ) -> List [str ]:
180202 """Evaluate an expression without parentheses."""
181203
182204 # Split the expression by the OR operator and process each segment
183- for or_segment in exp .split ('|' ):
205+ for or_segment in exp .split ("|" ):
184206
185207 # Check if all elements in an AND segment are in state
186- and_segment = or_segment .split ('&' )
208+ and_segment = or_segment .split ("&" )
187209 if all (elem .strip () in state for elem in and_segment ):
188- return [elem .strip () for elem in and_segment if elem .strip () in state ]
210+ return [
211+ elem .strip () for elem in and_segment if elem .strip () in state
212+ ]
189213 return []
190214
191215 # Helper function to evaluate expressions with parentheses
192216 def evaluate_expression (expression : str ) -> List [str ]:
193217 """Evaluate an expression with parentheses."""
194-
195- while '(' in expression :
196- start = expression .rfind ('(' )
197- end = expression .find (')' , start )
198- sub_exp = expression [start + 1 : end ]
218+
219+ while "(" in expression :
220+ start = expression .rfind ("(" )
221+ end = expression .find (")" , start )
222+ sub_exp = expression [start + 1 : end ]
199223
200224 # Replace the evaluated part with a placeholder and then evaluate it
201225 sub_result = evaluate_simple_expression (sub_exp )
202226
203227 # For simplicity in handling, join sub-results with OR to reprocess them later
204- expression = expression [:start ] + \
205- '|' .join (sub_result ) + expression [end + 1 :]
228+ expression = (
229+ expression [:start ] + "|" .join (sub_result ) + expression [end + 1 :]
230+ )
206231 return evaluate_simple_expression (expression )
207232
208233 result = evaluate_expression (expression )
0 commit comments