Skip to content

Commit ce97def

Browse files
authored
Merge pull request #402 from bioimage-io/always_test_all_weight_formats
always test all weight formats
2 parents 3275a99 + 25a7e00 commit ce97def

File tree

3 files changed

+25
-36
lines changed

3 files changed

+25
-36
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ The model specification and its validation tools can be found at <https://github
124124

125125
## Changelog
126126

127+
### 0.6.8 (to be released)
128+
129+
* testing model inference will now check all weight formats
130+
(previously only the first one for which model adapter creation succeeded had been checked)
131+
127132
### 0.6.7
128133

129134
* `predict()` argument `inputs` may be sample

bioimageio/core/_resource_tests.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,16 @@ def load_description_and_test(
103103
_test_expected_resource_type(rd, expected_type)
104104

105105
if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
106-
_test_model_inference(rd, weight_format, devices, decimal)
107-
if not isinstance(rd, v0_4.ModelDescr):
108-
_test_model_inference_parametrized(rd, weight_format, devices)
106+
if weight_format is None:
107+
weight_formats: List[WeightsFormat] = [
108+
w for w, we in rd.weights if we is not None
109+
] # pyright: ignore[reportAssignmentType]
110+
else:
111+
weight_formats = [weight_format]
112+
for w in weight_formats:
113+
_test_model_inference(rd, w, devices, decimal)
114+
if not isinstance(rd, v0_4.ModelDescr):
115+
_test_model_inference_parametrized(rd, w, devices)
109116

110117
# TODO: add execution of jupyter notebooks
111118
# TODO: add more tests
@@ -115,7 +122,7 @@ def load_description_and_test(
115122

116123
def _test_model_inference(
117124
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
118-
weight_format: Optional[WeightsFormat],
125+
weight_format: WeightsFormat,
119126
devices: Optional[List[str]],
120127
decimal: int,
121128
) -> None:
@@ -161,11 +168,7 @@ def _test_model_inference(
161168
if error is None
162169
else [
163170
ErrorEntry(
164-
loc=(
165-
("weights",)
166-
if weight_format is None
167-
else ("weights", weight_format)
168-
),
171+
loc=("weights", weight_format),
169172
msg=error,
170173
type="bioimageio.core",
171174
traceback=tb,
@@ -178,7 +181,7 @@ def _test_model_inference(
178181

179182
def _test_model_inference_parametrized(
180183
model: v0_5.ModelDescr,
181-
weight_format: Optional[WeightsFormat],
184+
weight_format: WeightsFormat,
182185
devices: Optional[List[str]],
183186
) -> None:
184187
if not any(
@@ -300,19 +303,15 @@ def get_ns(n: int):
300303

301304
model.validation_summary.add_detail(
302305
ValidationDetail(
303-
name="Run inference for inputs with batch_size:"
304-
+ f" {batch_size} and size parameter n: {n}",
306+
name=f"Run {weight_format} inference for inputs with"
307+
+ f" batch_size: {batch_size} and size parameter n: {n}",
305308
status="passed" if error is None else "failed",
306309
errors=(
307310
[]
308311
if error is None
309312
else [
310313
ErrorEntry(
311-
loc=(
312-
("weights",)
313-
if weight_format is None
314-
else ("weights", weight_format)
315-
),
314+
loc=("weights", weight_format),
316315
msg=error,
317316
type="bioimageio.core",
318317
)
@@ -325,15 +324,11 @@ def get_ns(n: int):
325324
tb = traceback.format_tb(e.__traceback__)
326325
model.validation_summary.add_detail(
327326
ValidationDetail(
328-
name="Run inference for parametrized inputs",
327+
name=f"Run {weight_format} inference for parametrized inputs",
329328
status="failed",
330329
errors=[
331330
ErrorEntry(
332-
loc=(
333-
("weights",)
334-
if weight_format is None
335-
else ("weights", weight_format)
336-
),
331+
loc=("weights", weight_format),
337332
msg=error,
338333
type="bioimageio.core",
339334
traceback=tb,

tests/conftest.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,23 +115,12 @@
115115
"unet2d_nuclei_broad_model",
116116
]
117117
)
118-
ONNX_MODELS = (
119-
[]
120-
if onnxruntime is None
121-
else [
122-
"hpa_densenet",
123-
"unet2d_multi_tensor",
124-
"unet2d_nuclei_broad_model",
125-
]
126-
)
118+
ONNX_MODELS = [] if onnxruntime is None else ["hpa_densenet"]
127119
TENSORFLOW_MODELS = (
128120
[]
129121
if tensorflow is None
130122
else (
131-
[
132-
"hpa_densenet",
133-
"stardist",
134-
]
123+
["stardist"]
135124
if tf_major_version == 1
136125
else [
137126
"unet2d_keras_tf2",

0 commit comments

Comments
 (0)