Skip to content

Commit 05e3c39

Browse files
authored
Merge pull request #3460 from rpgoldman/grapher-dfs
Fix for timeout in graph_model
2 parents d113e41 + ca13c44 commit 05e3c39

File tree

2 files changed

+35
-23
lines changed

2 files changed

+35
-23
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Used `numpy.vectorize` in `distributions.distribution._compile_theano_function`. This enables `sample_prior_predictive` and `sample_posterior_predictive` to ask for tuples of samples instead of just integers. This fixes issue #3422.
1313

1414
### Maintenance
15+
- Fixed an issue in `model_graph` that caused construction of the graph of the model for rendering to hang: replaced a search over the powerset of the nodes with a breadth-first search over the nodes. Fix for #3458.
1516
- All occurances of `sd` as a parameter name have been renamed to `sigma`. `sd` will continue to function for backwards compatibility.
1617
- Made `BrokenPipeError` for parallel sampling more verbose on Windows.
1718
- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).

pymc3/model_graph.py

+34-23
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import itertools
2+
from collections import deque
3+
from typing import Iterator, Optional, MutableSet
24

3-
from theano.gof.graph import ancestors
5+
from theano.gof.graph import stack_search
46
from theano.compile import SharedVariable
7+
from theano.tensor import Tensor
58

69
from .util import get_default_varnames
710
import pymc3 as pm
811

12+
# this is a placeholder for a better characterization of the type
13+
# of variables in a model.
14+
RV = Tensor
15+
916

1017
def powerset(iterable):
1118
"""All *nonempty* subsets of an iterable.
@@ -27,37 +34,41 @@ def __init__(self, model):
2734
self._deterministics = None
2835

2936
def get_deterministics(self, var):
30-
"""Compute the deterministic nodes of the graph"""
37+
"""Compute the deterministic nodes of the graph, **not** including var itself."""
3138
deterministics = []
3239
attrs = ('transformed', 'logpt')
3340
for v in self.var_list:
3441
if v != var and all(not hasattr(v, attr) for attr in attrs):
3542
deterministics.append(v)
3643
return deterministics
3744

38-
def _ancestors(self, var, func, blockers=None):
39-
"""Get ancestors of a function that are also named PyMC3 variables"""
40-
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])
45+
def _get_ancestors(self, var, func) -> MutableSet[RV]:
46+
"""Get all ancestors of a function, doing some accounting for deterministics.
47+
"""
4148

42-
def _get_ancestors(self, var, func):
43-
"""Get all ancestors of a function, doing some accounting for deterministics
49+
# this contains all of the variables in the model EXCEPT var...
50+
vars: MutableSet[RV] = set(self.var_list)
51+
vars.remove(var)
52+
53+
blockers: MutableSet[RV] = set()
54+
retval = set()
55+
def _expand(node) -> Optional[Iterator[Tensor]]:
56+
if node in blockers:
57+
return None
58+
elif node in vars:
59+
blockers.add(node)
60+
retval.add(node)
61+
return None
62+
elif node.owner:
63+
blockers.add(node)
64+
return reversed(node.owner.inputs)
65+
else:
66+
return None
4467

45-
Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
46-
return only the inputs *to the deterministic*. However, if we pass in the
47-
deterministic as a blocker, it will skip those nodes.
48-
"""
49-
deterministics = self.get_deterministics(var)
50-
upstream = self._ancestors(var, func)
51-
52-
# Usual case
53-
if upstream == self._ancestors(var, func, blockers=upstream):
54-
return upstream
55-
else: # deterministic accounting
56-
for d in powerset(upstream):
57-
blocked = self._ancestors(var, func, blockers=d)
58-
if set(d) == blocked:
59-
return d
60-
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')
68+
stack_search(start = deque([func]),
69+
expand=_expand,
70+
mode='bfs')
71+
return retval
6172

6273
def _filter_parents(self, var, parents):
6374
"""Get direct parents of a var, as strings"""

0 commit comments

Comments
 (0)