Skip to content

Commit 8981c7e

Browse files
committed
Simplify _get_ancestors().
Per @lucianopaz. remove the initial special case/
1 parent f76df8c commit 8981c7e

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

pymc3/model_graph.py

+19-25
Original file line numberDiff line numberDiff line change
@@ -54,31 +54,25 @@ def _get_ancestors(self, var, func):
5454
vars: List[var] = set(self.var_list)
5555
vars.remove(var)
5656

57-
upstream = self._ancestors(var, func)
58-
59-
# Usual case
60-
if upstream == self._ancestors(var, func, blockers=upstream):
61-
return upstream
62-
else: # deterministic accounting
63-
blockers = set()
64-
retval = set()
65-
def _expand(node) -> Optional[Iterator[Tensor]]:
66-
if node in blockers:
67-
return None
68-
elif node in vars:
69-
blockers.add(node)
70-
retval.add(node)
71-
return None
72-
elif node.owner:
73-
blockers.add(node)
74-
return reversed(node.owner.inputs)
75-
else:
76-
return None
77-
78-
stack_search(start = deque([func]),
79-
expand=_expand,
80-
mode='bfs')
81-
return retval
57+
blockers = set()
58+
retval = set()
59+
def _expand(node) -> Optional[Iterator[Tensor]]:
60+
if node in blockers:
61+
return None
62+
elif node in vars:
63+
blockers.add(node)
64+
retval.add(node)
65+
return None
66+
elif node.owner:
67+
blockers.add(node)
68+
return reversed(node.owner.inputs)
69+
else:
70+
return None
71+
72+
stack_search(start = deque([func]),
73+
expand=_expand,
74+
mode='bfs')
75+
return retval
8276

8377
def _filter_parents(self, var, parents):
8478
"""Get direct parents of a var, as strings"""

0 commit comments

Comments
 (0)