Skip to content

Commit f76df8c

Browse files
committed
Revise _get_ancestors to use BFS.
Previously, `get_ancestors()` used a powerset computation, which would fail on large models. Replaced with breadth-first search.
1 parent d113e41 commit f76df8c

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

pymc3/model_graph.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import itertools
2+
from collections import deque
3+
from typing import Iterator, Optional
24

3-
from theano.gof.graph import ancestors
5+
from theano.gof.graph import ancestors, stack_search
46
from theano.compile import SharedVariable
7+
from theano.tensor import Tensor
58

69
from .util import get_default_varnames
710
import pymc3 as pm
811

12+
# this is a placeholder for a better characterization of the type
13+
# of variables in a model.
14+
RV = Tensor
15+
916

1017
def powerset(iterable):
1118
"""All *nonempty* subsets of an iterable.
@@ -27,7 +34,7 @@ def __init__(self, model):
2734
self._deterministics = None
2835

2936
def get_deterministics(self, var):
30-
"""Compute the deterministic nodes of the graph"""
37+
"""Compute the deterministic nodes of the graph, **not** including var itself."""
3138
deterministics = []
3239
attrs = ('transformed', 'logpt')
3340
for v in self.var_list:
@@ -40,24 +47,38 @@ def _ancestors(self, var, func, blockers=None):
4047
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])
4148

4249
def _get_ancestors(self, var, func):
43-
"""Get all ancestors of a function, doing some accounting for deterministics
44-
45-
Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
46-
return only the inputs *to the deterministic*. However, if we pass in the
47-
deterministic as a blocker, it will skip those nodes.
50+
"""Get all ancestors of a function, doing some accounting for deterministics.
4851
"""
49-
deterministics = self.get_deterministics(var)
52+
53+
# this contains all of the variables in the model EXCEPT var...
54+
vars: List[var] = set(self.var_list)
55+
vars.remove(var)
56+
5057
upstream = self._ancestors(var, func)
5158

5259
# Usual case
5360
if upstream == self._ancestors(var, func, blockers=upstream):
5461
return upstream
5562
else: # deterministic accounting
56-
for d in powerset(upstream):
57-
blocked = self._ancestors(var, func, blockers=d)
58-
if set(d) == blocked:
59-
return d
60-
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')
63+
blockers = set()
64+
retval = set()
65+
def _expand(node) -> Optional[Iterator[Tensor]]:
66+
if node in blockers:
67+
return None
68+
elif node in vars:
69+
blockers.add(node)
70+
retval.add(node)
71+
return None
72+
elif node.owner:
73+
blockers.add(node)
74+
return reversed(node.owner.inputs)
75+
else:
76+
return None
77+
78+
stack_search(start = deque([func]),
79+
expand=_expand,
80+
mode='bfs')
81+
return retval
6182

6283
def _filter_parents(self, var, parents):
6384
"""Get direct parents of a var, as strings"""

0 commit comments

Comments
 (0)