Skip to content

Commit 8c81624

Browse files
kyleabeauchampColCarroll
authored andcommitted
Fix test_dist_math tests for float32 (#2268)
1 parent 84db639 commit 8c81624

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pymc3/tests/test_dist_math.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ class TestMvNormalLogp():
125125
def test_logp(self):
126126
np.random.seed(42)
127127

128-
chol_val = np.array([[1, 0.9], [0, 2]])
129-
cov_val = np.dot(chol_val, chol_val.T)
128+
chol_val = floatX(np.array([[1, 0.9], [0, 2]]))
129+
cov_val = floatX(np.dot(chol_val, chol_val.T))
130130
cov = tt.matrix('cov')
131131
cov.tag.test_value = cov_val
132-
delta_val = np.random.randn(5, 2)
132+
delta_val = floatX(np.random.randn(5, 2))
133133
delta = tt.matrix('delta')
134134
delta.tag.test_value = delta_val
135135
expect = stats.multivariate_normal(mean=np.zeros(2), cov=cov_val)
@@ -151,15 +151,15 @@ def func(chol_vec, delta):
151151
cov = tt.dot(chol, chol.T)
152152
return MvNormalLogp()(cov, delta)
153153

154-
chol_vec_val = np.array([0.5, 1., -0.1])
154+
chol_vec_val = floatX(np.array([0.5, 1., -0.1]))
155155

156-
delta_val = np.random.randn(1, 2)
156+
delta_val = floatX(np.random.randn(1, 2))
157157
try:
158158
utt.verify_grad(func, [chol_vec_val, delta_val])
159159
except ValueError as e:
160160
print(e.args[0])
161161

162-
delta_val = np.random.randn(5, 2)
162+
delta_val = floatX(np.random.randn(5, 2))
163163
utt.verify_grad(func, [chol_vec_val, delta_val])
164164

165165
@pytest.mark.skip(reason="Fix in theano not released yet: Theano#5908")

0 commit comments

Comments
 (0)