Skip to content

Commit 69b6a46

Browse files
tirthasheshpatelantmarakis
authored andcommittedJan 8, 2020
[WIP] ENH: add support for all types of problems in Bidirectional Search (#1147)
* ENH: all problems can now use BS * TST: add test for all types of problems for BS
1 parent ec2111a commit 69b6a46

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed
 

‎search.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,11 @@ def iterative_deepening_search(problem):
327327
# Pseudocode from https://webdocs.cs.ualberta.ca/%7Eholte/Publications/MM-AAAI2016.pdf
328328

329329
def bidirectional_search(problem):
330-
e = problem.find_min_edge()
331-
gF, gB = {problem.initial: 0}, {problem.goal: 0}
332-
openF, openB = [problem.initial], [problem.goal]
330+
e = 0
331+
if isinstance(problem, GraphProblem):
332+
e = problem.find_min_edge()
333+
gF, gB = {Node(problem.initial): 0}, {Node(problem.goal): 0}
334+
openF, openB = [Node(problem.initial)], [Node(problem.goal)]
333335
closedF, closedB = [], []
334336
U = np.inf
335337

@@ -340,14 +342,14 @@ def extend(U, open_dir, open_other, g_dir, g_other, closed_dir):
340342
open_dir.remove(n)
341343
closed_dir.append(n)
342344

343-
for c in problem.actions(n):
345+
for c in n.expand(problem):
344346
if c in open_dir or c in closed_dir:
345-
if g_dir[c] <= problem.path_cost(g_dir[n], n, None, c):
347+
if g_dir[c] <= problem.path_cost(g_dir[n], n.state, None, c.state):
346348
continue
347349

348350
open_dir.remove(c)
349351

350-
g_dir[c] = problem.path_cost(g_dir[n], n, None, c)
352+
g_dir[c] = problem.path_cost(g_dir[n], n.state, None, c.state)
351353
open_dir.append(c)
352354

353355
if c in open_other:
@@ -372,15 +374,15 @@ def find_key(pr_min, open_dir, g):
372374
"""Finds key in open_dir with value equal to pr_min
373375
and minimum g value."""
374376
m = np.inf
375-
state = -1
377+
node = Node(-1)
376378
for n in open_dir:
377379
pr = max(g[n] + problem.h(n), 2 * g[n])
378380
if pr == pr_min:
379381
if g[n] < m:
380382
m = g[n]
381-
state = n
383+
node = n
382384

383-
return state
385+
return node
384386

385387
while openF and openB:
386388
pr_min_f, f_min_f, g_min_f = find_min(openF, gF)

‎tests/test_search.py

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def test_depth_limited_search():
7171

7272
def test_bidirectional_search():
7373
assert bidirectional_search(romania_problem) == 418
74+
assert bidirectional_search(eight_puzzle) == 12
75+
assert bidirectional_search(EightPuzzle((1, 2, 3, 4, 5, 6, 0, 7, 8))) == 2
7476

7577

7678
def test_astar_search():

0 commit comments

Comments
 (0)
Please sign in to comment.