Skip to content

Commit 8cc2629

Browse files
committed
TEST: Refine parameter and assertion precision
1 parent 381ad1c commit 8cc2629

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

nibabel/tests/test_proxy_api.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from nose.tools import (assert_true, assert_false, assert_raises,
5656
assert_equal, assert_not_equal, assert_greater_equal)
5757

58-
from numpy.testing import (assert_almost_equal, assert_array_equal)
58+
from numpy.testing import assert_almost_equal, assert_array_equal, assert_allclose
5959

6060
from ..testing import data_path as DATA_PATH, assert_dt_equal
6161

@@ -143,7 +143,10 @@ def validate_get_scaled(self, pmaker, params):
143143

144144
for dtype in np.sctypes['float'] + np.sctypes['int'] + np.sctypes['uint']:
145145
out = prox.get_scaled(dtype=dtype)
146-
assert_almost_equal(out, params['arr_out'])
146+
# Half-precision is imprecise. Obviously. It's a bad idea, but don't break
147+
# the test over it.
148+
rtol = 1e-03 if dtype == np.float16 else 1e-05
149+
assert_allclose(out, params['arr_out'].astype(out.dtype), rtol=rtol, atol=1e-08)
147150
assert_greater_equal(out.dtype, np.dtype(dtype))
148151
# Shape matches expected shape
149152
assert_equal(out.shape, params['shape'])
@@ -218,8 +221,8 @@ def obj_params(self):
218221
offsets = (self.header_class().get_data_offset(),)
219222
else:
220223
offsets = (0, 16)
221-
slopes = (1., 2., 3.1416) if self.has_slope else (1.,)
222-
inters = (0., 10., 2.7183) if self.has_inter else (0.,)
224+
slopes = (1., 2., float(np.float32(3.1416))) if self.has_slope else (1.,)
225+
inters = (0., 10., float(np.float32(2.7183))) if self.has_inter else (0.,)
223226
for shape, dtype, offset, slope, inter in product(self.shapes,
224227
self.data_dtypes,
225228
offsets,
@@ -263,7 +266,7 @@ def sio_func():
263266
dtype=dtype,
264267
dtype_out=dtype_out,
265268
arr=arr.copy(),
266-
arr_out=arr * slope + inter,
269+
arr_out=arr.astype(dtype_out) * slope + inter,
267270
shape=shape,
268271
offset=offset,
269272
slope=slope,

0 commit comments

Comments
 (0)