|
1 | 1 | import itertools
|
2 | 2 | from collections import deque
|
3 |
| -from typing import Iterator, Optional |
| 3 | +from typing import Iterator, Optional, MutableSet |
4 | 4 |
|
5 |
| -from theano.gof.graph import ancestors, stack_search |
| 5 | +from theano.gof.graph import stack_search |
6 | 6 | from theano.compile import SharedVariable
|
7 | 7 | from theano.tensor import Tensor
|
8 | 8 |
|
@@ -42,19 +42,15 @@ def get_deterministics(self, var):
|
42 | 42 | deterministics.append(v)
|
43 | 43 | return deterministics
|
44 | 44 |
|
45 |
| - def _ancestors(self, var, func, blockers=None): |
46 |
| - """Get ancestors of a function that are also named PyMC3 variables""" |
47 |
| - return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var]) |
48 |
| - |
49 |
| - def _get_ancestors(self, var, func): |
| 45 | + def _get_ancestors(self, var, func) -> MutableSet[RV]: |
50 | 46 | """Get all ancestors of a function, doing some accounting for deterministics.
|
51 | 47 | """
|
52 | 48 |
|
53 | 49 | # this contains all of the variables in the model EXCEPT var...
|
54 |
| - vars: List[var] = set(self.var_list) |
| 50 | + vars: MutableSet[RV] = set(self.var_list) |
55 | 51 | vars.remove(var)
|
56 | 52 |
|
57 |
| - blockers = set() |
| 53 | + blockers: MutableSet[RV] = set() |
58 | 54 | retval = set()
|
59 | 55 | def _expand(node) -> Optional[Iterator[Tensor]]:
|
60 | 56 | if node in blockers:
|
|
0 commit comments