Skip to content

Commit 5a44793

Browse files
committed
Make rv_size_is_none more robust
1 parent 2db28f0 commit 5a44793

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pymc/distributions/shape_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import numpy as np
2424

25-
from aesara.graph.basic import Constant, Variable
25+
from aesara.graph.basic import Variable
2626
from aesara.tensor.var import TensorVariable
2727
from typing_extensions import TypeAlias
2828

@@ -618,4 +618,4 @@ def find_size(
618618

619619
def rv_size_is_none(size: Variable) -> bool:
620620
"""Check wether an rv size is None (ie., at.Constant([]))"""
621-
return isinstance(size, Constant) and size.data.size == 0
621+
return size.type.shape == (0,)

pymc/tests/test_distributions_moments.py

+3
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def test_rv_size_is_none():
133133
rv = Normal.dist(0, 1, size=None)
134134
assert rv_size_is_none(rv.owner.inputs[1])
135135

136+
rv = Normal.dist(0, 1, size=())
137+
assert rv_size_is_none(rv.owner.inputs[1])
138+
136139
rv = Normal.dist(0, 1, size=1)
137140
assert not rv_size_is_none(rv.owner.inputs[1])
138141

0 commit comments

Comments
 (0)