|
4 | 4 | from itertools import zip_longest
|
5 | 5 |
|
6 | 6 | from pymc import SymbolicRandomVariable
|
| 7 | +from pymc.model.fgraph import ModelVar |
7 | 8 | from pytensor.compile import SharedVariable
|
8 | 9 | from pytensor.graph import Constant, Variable, ancestors
|
9 | 10 | from pytensor.graph.basic import io_toposort
|
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
|
35 | 36 |
|
36 | 37 | def find_conditional_input_rvs(output_rvs, all_rvs):
|
37 | 38 | """Find conditionally indepedent input RVs."""
|
38 |
| - blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] |
39 |
| - blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) |
| 39 | + other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] |
| 40 | + blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs)) |
40 | 41 | return [
|
41 | 42 | var
|
42 | 43 | for var in ancestors(output_rvs, blockers=blockers)
|
43 |
| - if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable)) |
| 44 | + if var in other_rvs |
44 | 45 | ]
|
45 | 46 |
|
46 | 47 |
|
@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
|
141 | 142 | # None of the inputs are related to the batch_axes of the input_vars
|
142 | 143 | continue
|
143 | 144 |
|
| 145 | + elif isinstance(node.op, ModelVar): |
| 146 | + var_dims[node.outputs[0]] = inputs_dims[0] |
| 147 | + |
144 | 148 | elif isinstance(node.op, DimShuffle):
|
145 | 149 | [input_dims] = inputs_dims
|
146 | 150 | output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
|
|
0 commit comments