66
77Path = list [tuple [int , int ]]
88
9- grid = [
10- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
11- [0 , 1 , 0 , 0 , 0 , 0 , 0 ], # 0 are free path whereas 1's are obstacles
12- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
13- [0 , 0 , 1 , 0 , 0 , 0 , 0 ],
14- [1 , 0 , 1 , 0 , 0 , 0 , 0 ],
15- [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
16- [0 , 0 , 0 , 0 , 1 , 0 , 0 ],
9+ # 0's are free path whereas 1's are obstacles
10+ TEST_GRIDS = [
11+ [
12+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
13+ [0 , 1 , 0 , 0 , 0 , 0 , 0 ],
14+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
15+ [0 , 0 , 1 , 0 , 0 , 0 , 0 ],
16+ [1 , 0 , 1 , 0 , 0 , 0 , 0 ],
17+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
18+ [0 , 0 , 0 , 0 , 1 , 0 , 0 ],
19+ ],
20+ [
21+ [0 , 0 , 0 , 1 , 1 , 0 , 0 ],
22+ [0 , 0 , 0 , 0 , 1 , 0 , 1 ],
23+ [0 , 0 , 0 , 1 , 1 , 0 , 0 ],
24+ [0 , 1 , 0 , 0 , 1 , 0 , 0 ],
25+ [1 , 0 , 0 , 1 , 1 , 0 , 1 ],
26+ [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
27+ ],
28+ [
29+ [0 , 0 , 1 , 0 , 0 ],
30+ [0 , 1 , 0 , 0 , 0 ],
31+ [0 , 0 , 1 , 0 , 1 ],
32+ [1 , 0 , 0 , 1 , 1 ],
33+ [0 , 0 , 0 , 0 , 0 ],
34+ ],
1735]
1836
1937delta = ([- 1 , 0 ], [0 , - 1 ], [1 , 0 ], [0 , 1 ]) # up, left, down, right
@@ -65,10 +83,14 @@ def calculate_heuristic(self) -> float:
6583 def __lt__ (self , other ) -> bool :
6684 return self .f_cost < other .f_cost
6785
86+ def __eq__ (self , other ) -> bool :
87+ return self .pos == other .pos
88+
6889
6990class GreedyBestFirst :
7091 """
71- >>> gbf = GreedyBestFirst((0, 0), (len(grid) - 1, len(grid[0]) - 1))
92+ >>> grid = TEST_GRIDS[2]
93+ >>> gbf = GreedyBestFirst(grid, (0, 0), (len(grid) - 1, len(grid[0]) - 1))
7294 >>> [x.pos for x in gbf.get_successors(gbf.start)]
7395 [(1, 0), (0, 1)]
7496 >>> (gbf.start.pos_y + delta[3][0], gbf.start.pos_x + delta[3][1])
@@ -78,11 +100,14 @@ class GreedyBestFirst:
78100 >>> gbf.retrace_path(gbf.start)
79101 [(0, 0)]
80102 >>> gbf.search() # doctest: +NORMALIZE_WHITESPACE
81- [(0, 0), (1, 0), (2, 0), (3, 0 ), (3, 1), (4, 1), (5, 1 ), (6, 1 ),
82- (6, 2), (6, 3), (5, 3), (5, 4), (5, 5), (6, 5), (6, 6 )]
103+ [(0, 0), (1, 0), (2, 0), (2, 1 ), (3, 1), (4, 1), (4, 2 ), (4, 3 ),
104+ (4, 4 )]
83105 """
84106
85- def __init__ (self , start : tuple [int , int ], goal : tuple [int , int ]):
107+ def __init__ (
108+ self , grid : list [list [int ]], start : tuple [int , int ], goal : tuple [int , int ]
109+ ):
110+ self .grid = grid
86111 self .start = Node (start [1 ], start [0 ], goal [1 ], goal [0 ], 0 , None )
87112 self .target = Node (goal [1 ], goal [0 ], goal [1 ], goal [0 ], 99999 , None )
88113
@@ -114,14 +139,6 @@ def search(self) -> Path | None:
114139
115140 if child_node not in self .open_nodes :
116141 self .open_nodes .append (child_node )
117- else :
118- # retrieve the best current path
119- better_node = self .open_nodes .pop (self .open_nodes .index (child_node ))
120-
121- if child_node .g_cost < better_node .g_cost :
122- self .open_nodes .append (child_node )
123- else :
124- self .open_nodes .append (better_node )
125142
126143 if not self .reached :
127144 return [self .start .pos ]
@@ -131,28 +148,22 @@ def get_successors(self, parent: Node) -> list[Node]:
131148 """
132149 Returns a list of successors (both in the grid and free spaces)
133150 """
134- successors = []
135- for action in delta :
136- pos_x = parent .pos_x + action [1 ]
137- pos_y = parent .pos_y + action [0 ]
138-
139- if not (0 <= pos_x <= len (grid [0 ]) - 1 and 0 <= pos_y <= len (grid ) - 1 ):
140- continue
141-
142- if grid [pos_y ][pos_x ] != 0 :
143- continue
144-
145- successors .append (
146- Node (
147- pos_x ,
148- pos_y ,
149- self .target .pos_y ,
150- self .target .pos_x ,
151- parent .g_cost + 1 ,
152- parent ,
153- )
151+ return [
152+ Node (
153+ pos_x ,
154+ pos_y ,
155+ self .target .pos_x ,
156+ self .target .pos_y ,
157+ parent .g_cost + 1 ,
158+ parent ,
159+ )
160+ for action in delta
161+ if (
162+ 0 <= (pos_x := parent .pos_x + action [1 ]) < len (self .grid [0 ])
163+ and 0 <= (pos_y := parent .pos_y + action [0 ]) < len (self .grid )
164+ and self .grid [pos_y ][pos_x ] == 0
154165 )
155- return successors
166+ ]
156167
157168 def retrace_path (self , node : Node | None ) -> Path :
158169 """
@@ -168,18 +179,21 @@ def retrace_path(self, node: Node | None) -> Path:
168179
169180
170181if __name__ == "__main__" :
171- init = (0 , 0 )
172- goal = (len (grid ) - 1 , len (grid [0 ]) - 1 )
173- for elem in grid :
174- print (elem )
175-
176- print ("------" )
177-
178- greedy_bf = GreedyBestFirst (init , goal )
179- path = greedy_bf .search ()
180- if path :
181- for pos_x , pos_y in path :
182- grid [pos_x ][pos_y ] = 2
182+ for idx , grid in enumerate (TEST_GRIDS ):
183+ print (f"==grid-{ idx + 1 } ==" )
183184
185+ init = (0 , 0 )
186+ goal = (len (grid ) - 1 , len (grid [0 ]) - 1 )
184187 for elem in grid :
185188 print (elem )
189+
190+ print ("------" )
191+
192+ greedy_bf = GreedyBestFirst (grid , init , goal )
193+ path = greedy_bf .search ()
194+ if path :
195+ for pos_x , pos_y in path :
196+ grid [pos_x ][pos_y ] = 2
197+
198+ for elem in grid :
199+ print (elem )
0 commit comments