Skip to content

Commit 08cbb33

Browse files
authored
Merge branch 'v4' into reintro_shape
2 parents ab5f44f + 0970af0 commit 08cbb33

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

pymc3/aesaraf.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,13 @@ def pandas_to_array(data):
8888
if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
8989
# typically, but not limited to pandas objects
9090
vals = data.to_numpy()
91-
mask = data.isnull().to_numpy()
91+
null_data = data.isnull()
92+
if hasattr(null_data, "to_numpy"):
93+
# pandas Series
94+
mask = null_data.to_numpy()
95+
else:
96+
# pandas Index
97+
mask = null_data
9298
if mask.any():
9399
# there are missing values
94100
ret = np.ma.MaskedArray(vals, mask)

pymc3/tests/test_aesaraf.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_change_rv_size():
5252
loc = at.as_tensor_variable([1, 2])
5353
rv = normal(loc=loc)
5454
assert rv.ndim == 1
55-
assert rv.eval().shape == (2,)
55+
assert tuple(rv.shape.eval()) == (2,)
5656

5757
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
5858
change_rv_size(rv, new_size=[[2, 3]])
@@ -61,7 +61,7 @@ def test_change_rv_size():
6161

6262
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
6363
assert rv_new.ndim == 2
64-
assert rv_new.eval().shape == (3, 2)
64+
assert tuple(rv_new.shape.eval()) == (3, 2)
6565

6666
# Make sure that the shape used to determine the expanded size doesn't
6767
# depend on the old `RandomVariable`.
@@ -71,7 +71,7 @@ def test_change_rv_size():
7171

7272
rv_newer = change_rv_size(rv_new, new_size=(4,), expand=True)
7373
assert rv_newer.ndim == 3
74-
assert rv_newer.eval().shape == (4, 3, 2)
74+
assert tuple(rv_newer.shape.eval()) == (4, 3, 2)
7575

7676
# Make sure we avoid introducing a `Cast` by converting the new size before
7777
# constructing the new `RandomVariable`
@@ -80,7 +80,19 @@ def test_change_rv_size():
8080
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
8181
assert rv_newer.ndim == 2
8282
assert isinstance(rv_newer.owner.inputs[1], Constant)
83-
assert rv_newer.eval().shape == (4, 3)
83+
assert tuple(rv_newer.shape.eval()) == (4, 3)
84+
85+
rv = normal(0, 1)
86+
new_size = at.as_tensor(np.array([4, 3], dtype="int32"))
87+
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
88+
assert rv_newer.ndim == 2
89+
assert tuple(rv_newer.shape.eval()) == (4, 3)
90+
91+
rv = normal(0, 1)
92+
new_size = at.as_tensor(2, dtype="int32")
93+
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
94+
assert rv_newer.ndim == 1
95+
assert tuple(rv_newer.shape.eval()) == (2,)
8496

8597

8698
class TestBroadcasting:
@@ -436,6 +448,13 @@ def test_pandas_to_array(input_dtype):
436448
assert isinstance(wrapped, TensorVariable)
437449

438450

451+
def test_pandas_to_array_pandas_index():
452+
data = pd.Index([1, 2, 3])
453+
result = pandas_to_array(data)
454+
expected = np.array([1, 2, 3])
455+
np.testing.assert_array_equal(result, expected)
456+
457+
439458
def test_walk_model():
440459
d = at.vector("d")
441460
b = at.vector("b")

pymc3/tests/test_sampling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -688,10 +688,10 @@ def test_deterministic_of_observed_modified_interface(self):
688688
meas_in_1 = pm.aesaraf.floatX(2 + 4 * np.random.randn(100))
689689
meas_in_2 = pm.aesaraf.floatX(5 + 4 * np.random.randn(100))
690690
with pm.Model() as model:
691-
mu_in_1 = pm.Normal("mu_in_1", 0, 1)
692-
sigma_in_1 = pm.HalfNormal("sd_in_1", 1)
693-
mu_in_2 = pm.Normal("mu_in_2", 0, 1)
694-
sigma_in_2 = pm.HalfNormal("sd__in_2", 1)
691+
mu_in_1 = pm.Normal("mu_in_1", 0, 1, testval=0)
692+
sigma_in_1 = pm.HalfNormal("sd_in_1", 1, testval=1)
693+
mu_in_2 = pm.Normal("mu_in_2", 0, 1, testval=0)
694+
sigma_in_2 = pm.HalfNormal("sd__in_2", 1, testval=1)
695695

696696
in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1)
697697
in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2)

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
aesara>=2.0.8
1+
aesara>=2.0.9
22
arviz>=0.11.2
33
cachetools>=4.2.1
44
dill

0 commit comments

Comments
 (0)