diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index ce6a11d208..3709ef756e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -62,7 +62,7 @@ is_basic_idx, ) from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType +from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import ( @@ -289,9 +289,12 @@ def find_measurable_index_mixture(fgraph, node): # We don't support (non-scalar) integer array indexing as it can pick repeated values, # but the Mixture logprob assumes all mixture values are independent if any( - indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0 + ( + isinstance(indices, TensorVariable) + and indices.dtype.startswith("int") + and any(not b for b in indices.type.broadcastable) + ) for indices in mixing_indices - if not isinstance(indices, SliceConstant) ): return None