1
1
import itertools
2
+ from collections import deque
3
+ from typing import Iterator , Optional
2
4
3
- from theano .gof .graph import ancestors
5
+ from theano .gof .graph import ancestors , stack_search
4
6
from theano .compile import SharedVariable
7
+ from theano .tensor import Tensor
5
8
6
9
from .util import get_default_varnames
7
10
import pymc3 as pm
8
11
12
+ # this is a placeholder for a better characterization of the type
13
+ # of variables in a model.
14
+ RV = Tensor
15
+
9
16
10
17
def powerset (iterable ):
11
18
"""All *nonempty* subsets of an iterable.
@@ -27,7 +34,7 @@ def __init__(self, model):
27
34
self ._deterministics = None
28
35
29
36
def get_deterministics (self , var ):
30
- """Compute the deterministic nodes of the graph"""
37
+ """Compute the deterministic nodes of the graph, **not** including var itself. """
31
38
deterministics = []
32
39
attrs = ('transformed' , 'logpt' )
33
40
for v in self .var_list :
@@ -40,24 +47,38 @@ def _ancestors(self, var, func, blockers=None):
40
47
return set ([j for j in ancestors ([func ], blockers = blockers ) if j in self .var_list and j != var ])
41
48
42
49
def _get_ancestors (self , var , func ):
43
- """Get all ancestors of a function, doing some accounting for deterministics
44
-
45
- Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
46
- return only the inputs *to the deterministic*. However, if we pass in the
47
- deterministic as a blocker, it will skip those nodes.
50
+ """Get all ancestors of a function, doing some accounting for deterministics.
48
51
"""
49
- deterministics = self .get_deterministics (var )
52
+
53
+ # this contains all of the variables in the model EXCEPT var...
54
+ vars : List [var ] = set (self .var_list )
55
+ vars .remove (var )
56
+
50
57
upstream = self ._ancestors (var , func )
51
58
52
59
# Usual case
53
60
if upstream == self ._ancestors (var , func , blockers = upstream ):
54
61
return upstream
55
62
else : # deterministic accounting
56
- for d in powerset (upstream ):
57
- blocked = self ._ancestors (var , func , blockers = d )
58
- if set (d ) == blocked :
59
- return d
60
- raise RuntimeError ('Could not traverse graph. Consider raising an issue with developers.' )
63
+ blockers = set ()
64
+ retval = set ()
65
+ def _expand (node ) -> Optional [Iterator [Tensor ]]:
66
+ if node in blockers :
67
+ return None
68
+ elif node in vars :
69
+ blockers .add (node )
70
+ retval .add (node )
71
+ return None
72
+ elif node .owner :
73
+ blockers .add (node )
74
+ return reversed (node .owner .inputs )
75
+ else :
76
+ return None
77
+
78
+ stack_search (start = deque ([func ]),
79
+ expand = _expand ,
80
+ mode = 'bfs' )
81
+ return retval
61
82
62
83
def _filter_parents (self , var , parents ):
63
84
"""Get direct parents of a var, as strings"""
0 commit comments