Skip to content

Commit 738c9de

Browse files
michaelosthegetwiecki
authored andcommitted
Add OS-independent #4652 regression tests
1 parent f8e1a81 commit 738c9de

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

Diff for: pymc3/tests/test_ode.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import sys
1415

1516
import aesara
1617
import numpy as np
@@ -168,6 +169,9 @@ def ode_func_5(y, t, p):
168169
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)
169170

170171

172+
@pytest.mark.xfail(
173+
condition=sys.platform == "win32", reason="https://github.com/pymc-devs/aesara/issues/390"
174+
)
171175
def test_logp_scalar_ode():
172176
"""Test the computation of the log probability for these models"""
173177

Diff for: pymc3/tests/test_shape_handling.py

+30
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara
1516
import numpy as np
1617
import pytest
1718

@@ -219,3 +220,32 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
219220
prior = pm.sample_prior_predictive(samples=fixture_sizes)
220221
for rv in RVs:
221222
assert prior[rv.name].shape == size + tuple(rv.distribution.shape)
223+
224+
225+
@pytest.mark.xfail(reason="https://github.com/pymc-devs/aesara/issues/390")
226+
def test_size32_doesnt_break_broadcasting():
227+
size32 = at.constant([1, 10], dtype="int32")
228+
rv = pm.Normal.dist(0, 1, size=size32)
229+
assert rv.broadcastable == (True, False)
230+
231+
232+
def test_observed_with_column_vector():
233+
"""This test is related to https://github.com/pymc-devs/aesara/issues/390 which breaks
234+
broadcastability of column-vector RVs. This unexpected change in type can lead to
235+
incompatibilities during graph rewriting for model.logp evaluation.
236+
"""
237+
with pm.Model() as model:
238+
# The `observed` is a broadcastable column vector
239+
obs = at.as_tensor_variable(np.ones((3, 1), dtype=aesara.config.floatX))
240+
assert obs.broadcastable == (False, True)
241+
242+
# Both shapes describe broadcastable volumn vectors
243+
size64 = at.constant([3, 1], dtype="int64")
244+
# But the second shape is upcasted from an int32 vector
245+
cast64 = at.cast(at.constant([3, 1], dtype="int32"), dtype="int64")
246+
247+
pm.Normal("x_size64", mu=0, sd=1, size=size64, observed=obs)
248+
model.logp()
249+
250+
pm.Normal("x_cast64", mu=0, sd=1, size=cast64, observed=obs)
251+
model.logp()

0 commit comments

Comments
 (0)