@@ -35,13 +35,18 @@ def test_model(
35
35
source : Union [v0_5 .ModelDescr , PermissiveFileSource ],
36
36
weight_format : Optional [WeightsFormat ] = None ,
37
37
devices : Optional [List [str ]] = None ,
38
- decimal : int = 4 ,
38
+ absolute_tolerance : float = 1.5e-4 ,
39
+ relative_tolerance : float = 1e-4 ,
40
+ decimal : Optional [int ] = None ,
39
41
) -> ValidationSummary :
40
42
"""Test model inference"""
43
+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
41
44
return test_description (
42
45
source ,
43
46
weight_format = weight_format ,
44
47
devices = devices ,
48
+ absolute_tolerance = absolute_tolerance ,
49
+ relative_tolerance = relative_tolerance ,
45
50
decimal = decimal ,
46
51
expected_type = "model" ,
47
52
)
@@ -53,15 +58,20 @@ def test_description(
53
58
format_version : Union [Literal ["discover" , "latest" ], str ] = "discover" ,
54
59
weight_format : Optional [WeightsFormat ] = None ,
55
60
devices : Optional [List [str ]] = None ,
56
- decimal : int = 4 ,
61
+ absolute_tolerance : float = 1.5e-4 ,
62
+ relative_tolerance : float = 1e-4 ,
63
+ decimal : Optional [int ] = None ,
57
64
expected_type : Optional [str ] = None ,
58
65
) -> ValidationSummary :
59
66
"""Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
67
+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
60
68
rd = load_description_and_test (
61
69
source ,
62
70
format_version = format_version ,
63
71
weight_format = weight_format ,
64
72
devices = devices ,
73
+ absolute_tolerance = absolute_tolerance ,
74
+ relative_tolerance = relative_tolerance ,
65
75
decimal = decimal ,
66
76
expected_type = expected_type ,
67
77
)
@@ -74,10 +84,13 @@ def load_description_and_test(
74
84
format_version : Union [Literal ["discover" , "latest" ], str ] = "discover" ,
75
85
weight_format : Optional [WeightsFormat ] = None ,
76
86
devices : Optional [List [str ]] = None ,
77
- decimal : int = 4 ,
87
+ absolute_tolerance : float = 1.5e-4 ,
88
+ relative_tolerance : float = 1e-4 ,
89
+ decimal : Optional [int ] = None ,
78
90
expected_type : Optional [str ] = None ,
79
91
) -> Union [ResourceDescr , InvalidDescr ]:
80
92
"""Test RDF dynamically, e.g. model inference of test inputs"""
93
+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
81
94
if (
82
95
isinstance (source , ResourceDescrBase )
83
96
and format_version != "discover"
@@ -110,7 +123,9 @@ def load_description_and_test(
110
123
else :
111
124
weight_formats = [weight_format ]
112
125
for w in weight_formats :
113
- _test_model_inference (rd , w , devices , decimal )
126
+ _test_model_inference (
127
+ rd , w , devices , absolute_tolerance , relative_tolerance , decimal
128
+ )
114
129
if not isinstance (rd , v0_4 .ModelDescr ):
115
130
_test_model_inference_parametrized (rd , w , devices )
116
131
@@ -124,12 +139,21 @@ def _test_model_inference(
124
139
model : Union [v0_4 .ModelDescr , v0_5 .ModelDescr ],
125
140
weight_format : WeightsFormat ,
126
141
devices : Optional [List [str ]],
127
- decimal : int ,
142
+ absolute_tolerance : float ,
143
+ relative_tolerance : float ,
144
+ decimal : Optional [int ],
128
145
) -> None :
129
146
test_name = "Reproduce test outputs from test inputs"
130
147
logger .info ("starting '{}'" , test_name )
131
148
error : Optional [str ] = None
132
149
tb : List [str ] = []
150
+
151
+ precision_args = _handle_legacy_precision_args (
152
+ absolute_tolerance = absolute_tolerance ,
153
+ relative_tolerance = relative_tolerance ,
154
+ decimal = decimal ,
155
+ )
156
+
133
157
try :
134
158
inputs = get_test_inputs (model )
135
159
expected = get_test_outputs (model )
@@ -149,8 +173,11 @@ def _test_model_inference(
149
173
error = "Output tensors for test case may not be None"
150
174
break
151
175
try :
152
- np .testing .assert_array_almost_equal (
153
- res .data , exp .data , decimal = decimal
176
+ np .testing .assert_allclose (
177
+ res .data ,
178
+ exp .data ,
179
+ rtol = precision_args ["relative_tolerance" ],
180
+ atol = precision_args ["absolute_tolerance" ],
154
181
)
155
182
except AssertionError as e :
156
183
error = f"Output and expected output disagree:\n { e } "
@@ -361,6 +388,38 @@ def _test_expected_resource_type(
361
388
)
362
389
363
390
391
def _handle_legacy_precision_args(
    absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int]
) -> Dict[str, float]:
    """Resolve the tolerances to compare outputs with, honouring legacy `decimal`.

    Args:
        absolute_tolerance: Absolute tolerance to use when `decimal` is None.
        relative_tolerance: Relative tolerance to use when `decimal` is None.
        decimal: Deprecated. When not None it overrides both tolerance
            arguments and a warning is emitted.

    Returns:
        A dict with the keys "absolute_tolerance" and "relative_tolerance".
    """
    # Nothing to translate: the caller already uses the current arguments.
    if decimal is None:
        return {
            "absolute_tolerance": absolute_tolerance,
            "relative_tolerance": relative_tolerance,
        }

    # FutureWarning (unlike DeprecationWarning) is shown by default, so users
    # who still pass `decimal` actually get to see the notice.
    warnings.warn(
        "The argument `decimal` has been deprecated in favour of "
        "`relative_tolerance` and `absolute_tolerance`, with different validation "
        "logic, using `numpy.testing.assert_allclose`, see "
        "'https://numpy.org/doc/stable/reference/generated/"
        "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
        "cause validation to revert to the old behaviour.",
        FutureWarning,
    )
    # `decimal` overrides the new behaviour. The old check was
    # `numpy.testing.assert_array_almost_equal`, documented as
    # abs(desired - actual) < 1.5 * 10**(-decimal); a purely absolute tolerance
    # of that size with no relative part emulates it.
    return {
        "absolute_tolerance": 1.5 * 10 ** (-decimal),
        "relative_tolerance": 0.0,
    }
421
+
422
+
364
423
# def debug_model(
365
424
# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
366
425
# *,
0 commit comments