Skip to content

Commit db331df

Browse files
authored
Merge pull request #418 from melisande-c/rel_tol_validation
Use relative and absolute tolerance for validating model outputs
2 parents 9e1e1fe + 8cc219d commit db331df

File tree

1 file changed

+66
-7
lines changed

1 file changed

+66
-7
lines changed

bioimageio/core/_resource_tests.py

+66-7
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,18 @@ def test_model(
3535
source: Union[v0_5.ModelDescr, PermissiveFileSource],
3636
weight_format: Optional[WeightsFormat] = None,
3737
devices: Optional[List[str]] = None,
38-
decimal: int = 4,
38+
absolute_tolerance: float = 1.5e-4,
39+
relative_tolerance: float = 1e-4,
40+
decimal: Optional[int] = None,
3941
) -> ValidationSummary:
4042
"""Test model inference"""
43+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
4144
return test_description(
4245
source,
4346
weight_format=weight_format,
4447
devices=devices,
48+
absolute_tolerance=absolute_tolerance,
49+
relative_tolerance=relative_tolerance,
4550
decimal=decimal,
4651
expected_type="model",
4752
)
@@ -53,15 +58,20 @@ def test_description(
5358
format_version: Union[Literal["discover", "latest"], str] = "discover",
5459
weight_format: Optional[WeightsFormat] = None,
5560
devices: Optional[List[str]] = None,
56-
decimal: int = 4,
61+
absolute_tolerance: float = 1.5e-4,
62+
relative_tolerance: float = 1e-4,
63+
decimal: Optional[int] = None,
5764
expected_type: Optional[str] = None,
5865
) -> ValidationSummary:
5966
"""Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
67+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
6068
rd = load_description_and_test(
6169
source,
6270
format_version=format_version,
6371
weight_format=weight_format,
6472
devices=devices,
73+
absolute_tolerance=absolute_tolerance,
74+
relative_tolerance=relative_tolerance,
6575
decimal=decimal,
6676
expected_type=expected_type,
6777
)
@@ -74,10 +84,13 @@ def load_description_and_test(
7484
format_version: Union[Literal["discover", "latest"], str] = "discover",
7585
weight_format: Optional[WeightsFormat] = None,
7686
devices: Optional[List[str]] = None,
77-
decimal: int = 4,
87+
absolute_tolerance: float = 1.5e-4,
88+
relative_tolerance: float = 1e-4,
89+
decimal: Optional[int] = None,
7890
expected_type: Optional[str] = None,
7991
) -> Union[ResourceDescr, InvalidDescr]:
8092
"""Test RDF dynamically, e.g. model inference of test inputs"""
93+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
8194
if (
8295
isinstance(source, ResourceDescrBase)
8396
and format_version != "discover"
@@ -110,7 +123,9 @@ def load_description_and_test(
110123
else:
111124
weight_formats = [weight_format]
112125
for w in weight_formats:
113-
_test_model_inference(rd, w, devices, decimal)
126+
_test_model_inference(
127+
rd, w, devices, absolute_tolerance, relative_tolerance, decimal
128+
)
114129
if not isinstance(rd, v0_4.ModelDescr):
115130
_test_model_inference_parametrized(rd, w, devices)
116131

@@ -124,12 +139,21 @@ def _test_model_inference(
124139
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
125140
weight_format: WeightsFormat,
126141
devices: Optional[List[str]],
127-
decimal: int,
142+
absolute_tolerance: float,
143+
relative_tolerance: float,
144+
decimal: Optional[int],
128145
) -> None:
129146
test_name = "Reproduce test outputs from test inputs"
130147
logger.info("starting '{}'", test_name)
131148
error: Optional[str] = None
132149
tb: List[str] = []
150+
151+
precision_args = _handle_legacy_precision_args(
152+
absolute_tolerance=absolute_tolerance,
153+
relative_tolerance=relative_tolerance,
154+
decimal=decimal,
155+
)
156+
133157
try:
134158
inputs = get_test_inputs(model)
135159
expected = get_test_outputs(model)
@@ -149,8 +173,11 @@ def _test_model_inference(
149173
error = "Output tensors for test case may not be None"
150174
break
151175
try:
152-
np.testing.assert_array_almost_equal(
153-
res.data, exp.data, decimal=decimal
176+
np.testing.assert_allclose(
177+
res.data,
178+
exp.data,
179+
rtol=precision_args["relative_tolerance"],
180+
atol=precision_args["absolute_tolerance"],
154181
)
155182
except AssertionError as e:
156183
error = f"Output and expected output disagree:\n {e}"
@@ -361,6 +388,38 @@ def _test_expected_resource_type(
361388
)
362389

363390

391+
def _handle_legacy_precision_args(
392+
absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int]
393+
) -> Dict[str, float]:
394+
"""
395+
Transform the precision arguments to conform with the current implementation.
396+
397+
If the deprecated `decimal` argument is used it overrides the new behaviour with
398+
the old behaviour.
399+
"""
400+
# Already conforms with current implementation
401+
if decimal is None:
402+
return {
403+
"absolute_tolerance": absolute_tolerance,
404+
"relative_tolerance": relative_tolerance,
405+
}
406+
407+
warnings.warn(
408+
"The argument `decimal` has been depricated in favour of "
409+
+ "`relative_tolerance` and `absolute_tolerance`, with different validation "
410+
+ "logic, using `numpy.testing.assert_allclose, see "
411+
+ "'https://numpy.org/doc/stable/reference/generated/"
412+
+ "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
413+
+ "cause validation to revert to the old behaviour."
414+
)
415+
# decimal overrides new behaviour,
416+
# have to convert the params to emulate old behaviour
417+
return {
418+
"absolute_tolerance": 1.5 * 10 ** (-decimal),
419+
"relative_tolerance": 0,
420+
}
421+
422+
364423
# def debug_model(
365424
# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
366425
# *,

0 commit comments

Comments
 (0)