Skip to content

Commit 3177640

Browse files
committed
.WIP refactor MarginalModel
1 parent 5de65eb commit 3177640

File tree

4 files changed

+270
-603
lines changed

4 files changed

+270
-603
lines changed

pymc_experimental/model/marginal/distributions.py

+12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytensor.tensor as pt
55

6+
from pymc.distributions.distribution import _support_point
67
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
78
from pymc.logprob.abstract import MeasurableOp, _logprob
89
from pymc.logprob.basic import conditional_logp, logp
@@ -44,6 +45,17 @@ def support_axes(self) -> tuple[tuple[int]]:
4445
return tuple(support_axes_vars)
4546

4647

48+
@_support_point.register
49+
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
50+
rv_idx = rv.owner.outputs.index(rv)
51+
inner_rv = inline_ofg_outputs(op, inputs)[rv_idx]
52+
return _support_point(inner_rv.owner.op, inner_rv, *inner_rv.owner.inputs)
53+
54+
55+
56+
57+
58+
4759
class MarginalFiniteDiscreteRV(MarginalRV):
4860
"""Base class for Marginalized Finite Discrete RVs"""
4961

pymc_experimental/model/marginal/graph_analysis.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import zip_longest
55

66
from pymc import SymbolicRandomVariable
7+
from pymc.model.fgraph import ModelVar
78
from pytensor.compile import SharedVariable
89
from pytensor.graph import Constant, Variable, ancestors
910
from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
3536

3637
def find_conditional_input_rvs(output_rvs, all_rvs):
3738
"""Find conditionally indepedent input RVs."""
38-
blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
39-
blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
39+
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
40+
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
4041
return [
4142
var
4243
for var in ancestors(output_rvs, blockers=blockers)
43-
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
44+
if var in other_rvs
4445
]
4546

4647

@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
141142
# None of the inputs are related to the batch_axes of the input_vars
142143
continue
143144

145+
elif isinstance(node.op, ModelVar):
146+
var_dims[node.outputs[0]] = inputs_dims[0]
147+
144148
elif isinstance(node.op, DimShuffle):
145149
[input_dims] = inputs_dims
146150
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)

0 commit comments

Comments
 (0)