diff --git a/search.py b/search.py index 689671769..89f872079 100644 --- a/search.py +++ b/search.py @@ -327,9 +327,11 @@ def iterative_deepening_search(problem): # Pseudocode from https://webdocs.cs.ualberta.ca/%7Eholte/Publications/MM-AAAI2016.pdf def bidirectional_search(problem): - e = problem.find_min_edge() - gF, gB = {problem.initial: 0}, {problem.goal: 0} - openF, openB = [problem.initial], [problem.goal] + e = 0 + if isinstance(problem, GraphProblem): + e = problem.find_min_edge() + gF, gB = {Node(problem.initial): 0}, {Node(problem.goal): 0} + openF, openB = [Node(problem.initial)], [Node(problem.goal)] closedF, closedB = [], [] U = np.inf @@ -340,14 +342,14 @@ def extend(U, open_dir, open_other, g_dir, g_other, closed_dir): open_dir.remove(n) closed_dir.append(n) - for c in problem.actions(n): + for c in n.expand(problem): if c in open_dir or c in closed_dir: - if g_dir[c] <= problem.path_cost(g_dir[n], n, None, c): + if g_dir[c] <= problem.path_cost(g_dir[n], n.state, None, c.state): continue open_dir.remove(c) - g_dir[c] = problem.path_cost(g_dir[n], n, None, c) + g_dir[c] = problem.path_cost(g_dir[n], n.state, None, c.state) open_dir.append(c) if c in open_other: @@ -372,15 +374,15 @@ def find_key(pr_min, open_dir, g): """Finds key in open_dir with value equal to pr_min and minimum g value.""" m = np.inf - state = -1 + node = Node(-1) for n in open_dir: pr = max(g[n] + problem.h(n), 2 * g[n]) if pr == pr_min: if g[n] < m: m = g[n] - state = n + node = n - return state + return node while openF and openB: pr_min_f, f_min_f, g_min_f = find_min(openF, gF) diff --git a/tests/test_search.py b/tests/test_search.py index d37f8fa38..075a57312 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -71,6 +71,8 @@ def test_depth_limited_search(): def test_bidirectional_search(): assert bidirectional_search(romania_problem) == 418 + assert bidirectional_search(eight_puzzle) == 12 + assert bidirectional_search(EightPuzzle((1, 2, 3, 4, 5, 6, 0, 7, 8))) == 2 def test_astar_search():