21 | 21 | from pytensor.graph.replace import graph_replace, vectorize_graph
22 | 22 | from pytensor.scan import map as scan_map
23 | 23 | from pytensor.tensor import TensorType, TensorVariable
24 |    | -from pytensor.tensor.elemwise import Elemwise
   | 24 | +from pytensor.tensor.elemwise import DimShuffle, Elemwise
25 | 25 | from pytensor.tensor.shape import Shape
26 | 26 | from pytensor.tensor.special import log_softmax
27 | 27 |
@@ -598,7 +598,18 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
598 | 598 |     fg = FunctionGraph(outputs=output_rvs, clone=False)
599 | 599 |
600 | 600 |     non_elemwise_blockers = [
601 |     | -        o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
    | 601 | +        o
    | 602 | +        for node in fg.apply_nodes
    | 603 | +        if not (
    | 604 | +            isinstance(node.op, Elemwise)
    | 605 | +            # Allow expand_dims on the left
    | 606 | +            or (
    | 607 | +                isinstance(node.op, DimShuffle)
    | 608 | +                and not node.op.drop
    | 609 | +                and node.op.shuffle == sorted(node.op.shuffle)
    | 610 | +            )
    | 611 | +        )
    | 612 | +        for o in node.outputs
602 | 613 |     ]
603 | 614 |     blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
604 | 615 |     blockers = [var for var in blocker_candidates if var not in output_rvs]
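For context on the widened check in this hunk: a pure expand_dims shows up in PyTensor as a DimShuffle that drops no input dimensions and keeps the remaining ones in their original order, which is exactly what the `not node.op.drop` and sorted-shuffle conditions test. A minimal sketch (assuming only that PyTensor is importable; the variable names are made up):

    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    x = pt.matrix("x")
    y = x.dimshuffle("x", 0, 1)  # expand_dims on the left, equivalent to x[None]

    node = y.owner
    assert isinstance(node.op, DimShuffle)
    assert not node.op.drop                            # no input dims are dropped
    assert node.op.shuffle == sorted(node.op.shuffle)  # kept dims stay in order (no transpose)

    # A transpose, by contrast, is rejected by the sorted-shuffle condition
    z = x.dimshuffle(1, 0)
    assert z.owner.op.shuffle != sorted(z.owner.op.shuffle)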
@@ -698,16 +709,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
698 | 709 |
699 | 710 | def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
700 | 711 |     op = rv.owner.op
    | 712 | +    dist_params = rv.owner.op.dist_params(rv.owner)
701 | 713 |     if isinstance(op, Bernoulli):
702 | 714 |         return (0, 1)
703 | 715 |     elif isinstance(op, Categorical):
704 |     | -        p_param = rv.owner.inputs[3]
    | 716 | +        [p_param] = dist_params
705 | 717 |         return tuple(range(pt.get_vector_length(p_param)))
706 | 718 |     elif isinstance(op, DiscreteUniform):
707 |     | -        lower, upper = constant_fold(rv.owner.inputs[3:])
    | 719 | +        lower, upper = constant_fold(dist_params)
708 | 720 |         return tuple(np.arange(lower, upper + 1))
709 | 721 |     elif isinstance(op, DiscreteMarkovChain):
710 |     | -        P = rv.owner.inputs[0]
    | 722 | +        P, *_ = dist_params
711 | 723 |         return tuple(range(pt.get_vector_length(P[-1])))
712 | 724 |
713 | 725 |     raise NotImplementedError(f"Cannot compute domain for op {op}")
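As a usage sketch of the `dist_params` helper this hunk switches to (assuming a PyMC/PyTensor version whose random variable ops expose `dist_params`, which the change relies on): it returns only the distribution parameters of a node, skipping the rng/size inputs, so the branches no longer hard-code positional indices like `inputs[3]`.

    import pymc as pm
    import pytensor.tensor as pt

    rv = pm.Categorical.dist(p=[0.1, 0.3, 0.6])
    # dist_params skips the rng/size inputs and returns just the distribution parameters
    [p_param] = rv.owner.op.dist_params(rv.owner)
    domain = tuple(range(pt.get_vector_length(p_param)))
    print(domain)  # (0, 1, 2)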
@@ -827,11 +839,15 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
827 | 839 |     # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
828 | 840 |     # We do it entirely in logs, though.
829 | 841 |
830 |     | -    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
831 |     | -    # the initial distribution. This is robust to everything the user can throw at it.
832 |     | -    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
833 |     | -        batch_chain_value[..., 0]
834 |     | -    )
    | 842 | +    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    | 843 | +    # under the initial distribution. This is robust to everything the user can throw at it.
    | 844 | +    init_dist_value = init_dist_.type()
    | 845 | +    logp_init_dist = logp(init_dist_, init_dist_value)
    | 846 | +    # There is a degenerate batch dim for lags=1 (the only supported case),
    | 847 | +    # which we have to work around by expanding the batch value and then squeezing it out of the logp
    | 848 | +    batch_logp_init_dist = vectorize_graph(
    | 849 | +        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    | 850 | +    ).squeeze(1)
835 | 851 |     log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
836 | 852 |
837 | 853 |     def step_alpha(logp_emission, log_alpha, log_P):
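A minimal sketch of the vectorize_graph pattern used for the initial-state probabilities in this hunk, with a stand-in categorical initial distribution (the three-state `p`, the domain values, and the singleton extra dimension are illustrative assumptions, not the model's actual inputs):

    import pymc as pm
    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    init_dist = pm.Categorical.dist(p=[0.2, 0.3, 0.5])  # stand-in for init_dist_

    # Core logp graph built on a value variable of the same type as the RV
    init_dist_value = init_dist.type()
    logp_init_dist = pm.logp(init_dist, init_dist_value)

    # Batch the value over the domain, keeping a singleton dim that is squeezed out afterwards
    batch_value = pt.as_tensor([0, 1, 2])[:, None]
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_value}
    ).squeeze(1)
    print(batch_logp_init_dist.eval())  # log([0.2, 0.3, 0.5])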