Skip to content

Commit a4e9fba

Browse files
michaelosthege authored and twiecki committed
Refactor test to use InferenceData
1 parent 0a8d1d4 commit a4e9fba

File tree

1 file changed

+61
-8
lines changed

1 file changed

+61
-8
lines changed

pymc3/tests/test_data_container.py

+61-8
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,13 @@
1818

1919
from aesara import shared
2020
from aesara.tensor.sharedvar import ScalarSharedVariable
21+
from aesara.tensor.var import TensorVariable
2122

2223
import pymc3 as pm
2324

25+
from pymc3.aesaraf import floatX
2426
from pymc3.distributions import logpt
27+
from pymc3.exceptions import ShapeError
2528
from pymc3.tests.helpers import SeededTest
2629

2730

@@ -159,22 +162,40 @@ def test_shared_data_as_rv_input(self):
159162
"""
160163
with pm.Model() as m:
161164
x = pm.Data("x", [1.0, 2.0, 3.0])
162-
_ = pm.Normal("y", mu=x, size=3)
163-
trace = pm.sample(
164-
chains=1, return_inferencedata=False, compute_convergence_checks=False
165+
y = pm.Normal("y", mu=x, size=(2, 3))
166+
assert y.eval().shape == (2, 3)
167+
idata = pm.sample(
168+
chains=1,
169+
tune=500,
170+
draws=550,
171+
return_inferencedata=True,
172+
compute_convergence_checks=False,
165173
)
174+
samples = idata.posterior["y"]
175+
assert samples.shape == (1, 550, 2, 3)
166176

167177
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
168-
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
178+
np.testing.assert_allclose(
179+
np.array([1.0, 2.0, 3.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
180+
)
169181

170182
with m:
171183
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
172-
trace = pm.sample(
173-
chains=1, return_inferencedata=False, compute_convergence_checks=False
184+
assert y.eval().shape == (2, 3)
185+
idata = pm.sample(
186+
chains=1,
187+
tune=500,
188+
draws=620,
189+
return_inferencedata=True,
190+
compute_convergence_checks=False,
174191
)
192+
samples = idata.posterior["y"]
193+
assert samples.shape == (1, 620, 2, 3)
175194

176195
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
177-
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
196+
np.testing.assert_allclose(
197+
np.array([2.0, 4.0, 6.0]), samples.mean(("chain", "draw", "y_dim_0")), atol=1e-1
198+
)
178199

179200
def test_shared_scalar_as_rv_input(self):
180201
# See https://github.com/pymc-devs/pymc3/issues/3139
@@ -217,7 +238,7 @@ def test_set_data_to_non_data_container_variables(self):
217238
)
218239
with pytest.raises(TypeError) as error:
219240
pm.set_data({"beta": [1.1, 2.2, 3.3]}, model=model)
220-
error.match("defined as `pymc3.Data` inside the model")
241+
error.match("The variable `beta` must be a `SharedVariable`")
221242

222243
@pytest.mark.xfail(reason="Depends on ModelGraph")
223244
def test_model_to_graphviz_for_model_with_data_container(self):
@@ -283,6 +304,38 @@ def test_explicit_coords(self):
283304
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)
284305
assert pmodel.dim_lengths["columns"].eval() == 7
285306

307+
def test_symbolic_coords(self):
308+
"""
309+
In v4 dimensions can be created without passing coordinate values.
310+
Their lengths are then automatically linked to the corresponding Tensor dimension.
311+
"""
312+
with pm.Model() as pmodel:
313+
intensity = pm.Data("intensity", np.ones((2, 3)), dims=("row", "column"))
314+
assert "row" in pmodel.dim_lengths
315+
assert "column" in pmodel.dim_lengths
316+
assert isinstance(pmodel.dim_lengths["row"], TensorVariable)
317+
assert isinstance(pmodel.dim_lengths["column"], TensorVariable)
318+
assert pmodel.dim_lengths["row"].eval() == 2
319+
assert pmodel.dim_lengths["column"].eval() == 3
320+
321+
intensity.set_value(floatX(np.ones((4, 5))))
322+
assert pmodel.dim_lengths["row"].eval() == 4
323+
assert pmodel.dim_lengths["column"].eval() == 5
324+
325+
def test_no_resize_of_implied_dimensions(self):
326+
with pm.Model() as pmodel:
327+
# Imply a dimension through RV params
328+
pm.Normal("n", mu=[1, 2, 3], dims="city")
329+
# _Use_ the dimension for a data variable
330+
inhabitants = pm.Data("inhabitants", [100, 200, 300], dims="city")
331+
332+
# Attempting to re-size the dimension through the data variable would
333+
# cause shape problems in InferenceData conversion, because the RV remains (3,).
334+
with pytest.raises(
335+
ShapeError, match="was initialized from 'n' which is not a shared variable"
336+
):
337+
pmodel.set_data("inhabitants", [1, 2, 3, 4])
338+
286339
def test_implicit_coords_series(self):
287340
ser_sales = pd.Series(
288341
data=np.random.randint(low=0, high=30, size=22),

0 commit comments

Comments
 (0)