Skip to content

Commit 32822b7

Browse files
Add OS-independent #4652 regression tests
1 parent fbcce93 commit 32822b7

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

Diff for: pymc3/tests/test_ode.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import sys
1415

1516
import aesara
1617
import numpy as np
@@ -168,6 +169,9 @@ def ode_func_5(y, t, p):
168169
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)
169170

170171

172+
@pytest.mark.xfail(
173+
condition=sys.platform == "win32", reason="https://github.com/pymc-devs/aesara/issues/390"
174+
)
171175
def test_logp_scalar_ode():
172176
"""Test the computation of the log probability for these models"""
173177

Diff for: pymc3/tests/test_shape_handling.py

+26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara
1516
import numpy as np
1617
import pytest
1718

@@ -219,3 +220,28 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
219220
prior = pm.sample_prior_predictive(samples=fixture_sizes)
220221
for rv in RVs:
221222
assert prior[rv.name].shape == size + tuple(rv.distribution.shape)
223+
224+
225+
@pytest.mark.xfail(reason="https://github.com/pymc-devs/aesara/issues/390")
226+
def test_size32_doesnt_break_broadcasting():
227+
size32 = at.constant([1, 10], dtype="int32")
228+
rv = pm.Normal.dist(0, 1, size=size32)
229+
assert rv.broadcastable == (True, False)
230+
231+
232+
def test_observed_with_column_vector():
233+
with pm.Model() as model:
234+
# The `observed` is a broadcastable column vector
235+
obs = at.as_tensor_variable(np.ones((3, 1), dtype=aesara.config.floatX))
236+
assert obs.broadcastable == (False, True)
237+
238+
# Both shapes describe broadcastable volumn vectors
239+
size64 = at.constant([3, 1], dtype="int64")
240+
# But the second shape is upcasted from an int32 vector
241+
cast64 = at.cast(at.constant([3, 1], dtype="int32"), dtype="int64")
242+
243+
pm.Normal("x_size64", mu=0, sd=1, size=size64, observed=obs)
244+
model.logp()
245+
246+
pm.Normal("x_cast64", mu=0, sd=1, size=cast64, observed=obs)
247+
model.logp()

0 commit comments

Comments
 (0)