Skip to content

Commit 14f599e

Browse files
committed
add predict_sample_with_fixed_blocking
1 parent c766f12 commit 14f599e

File tree

2 files changed

+68
-30
lines changed

2 files changed

+68
-30
lines changed

bioimageio/core/_prediction_pipeline.py

+44-26
Original file line numberDiff line numberDiff line change
@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
179179
self.model_description.id or self.model_description.name
180180
)
181181

182-
def predict_sample_with_blocking(
182+
def predict_sample_with_fixed_blocking(
183183
self,
184184
sample: Sample,
185+
input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
186+
*,
185187
skip_preprocessing: bool = False,
186188
skip_postprocessing: bool = False,
187-
ns: Optional[
188-
Union[
189-
v0_5.ParameterizedSize_N,
190-
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
191-
]
192-
] = None,
193-
batch_size: Optional[int] = None,
194189
) -> Sample:
195-
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
196190
if not skip_preprocessing:
197191
self.apply_preprocessing(sample)
198192

199-
if isinstance(self.model_description, v0_4.ModelDescr):
200-
raise NotImplementedError(
201-
"predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
202-
)
203-
204-
ns = ns or self._default_ns
205-
if isinstance(ns, int):
206-
ns = {
207-
(ipt.id, a.id): ns
208-
for ipt in self.model_description.inputs
209-
for a in ipt.axes
210-
if isinstance(a.size, v0_5.ParameterizedSize)
211-
}
212-
input_block_shape = self.model_description.get_tensor_sizes(
213-
ns, batch_size or self._default_batch_size
214-
).inputs
215-
216193
n_blocks, input_blocks = sample.split_into_blocks(
217194
input_block_shape,
218195
halo=self._default_input_halo,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(
239216

240217
return predicted_sample
241218

219+
def predict_sample_with_blocking(
220+
self,
221+
sample: Sample,
222+
skip_preprocessing: bool = False,
223+
skip_postprocessing: bool = False,
224+
ns: Optional[
225+
Union[
226+
v0_5.ParameterizedSize_N,
227+
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
228+
]
229+
] = None,
230+
batch_size: Optional[int] = None,
231+
) -> Sample:
232+
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
233+
234+
if isinstance(self.model_description, v0_4.ModelDescr):
235+
raise NotImplementedError(
236+
"`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
237+
+ f" {self.model_description.name}."
238+
+ " Consider using `predict_sample_with_fixed_blocking`"
239+
)
240+
241+
ns = ns or self._default_ns
242+
if isinstance(ns, int):
243+
ns = {
244+
(ipt.id, a.id): ns
245+
for ipt in self.model_description.inputs
246+
for a in ipt.axes
247+
if isinstance(a.size, v0_5.ParameterizedSize)
248+
}
249+
input_block_shape = self.model_description.get_tensor_sizes(
250+
ns, batch_size or self._default_batch_size
251+
).inputs
252+
253+
return self.predict_sample_with_fixed_blocking(
254+
sample,
255+
input_block_shape=input_block_shape,
256+
skip_preprocessing=skip_preprocessing,
257+
skip_postprocessing=skip_postprocessing,
258+
)
259+
242260
# def predict(
243261
# self,
244262
# inputs: Predict_IO,

bioimageio/core/prediction.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313

1414
import xarray as xr
15+
from loguru import logger
1516
from numpy.typing import NDArray
1617
from tqdm import tqdm
1718

@@ -41,6 +42,7 @@ def predict(
4142
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
4243
]
4344
] = None,
45+
input_block_shape: Optional[Mapping[MemberId, Mapping[AxisId, int]]] = None,
4446
skip_preprocessing: bool = False,
4547
skip_postprocessing: bool = False,
4648
save_output_path: Optional[Union[Path, str]] = None,
@@ -53,7 +55,11 @@ def predict(
5355
inputs: the input sample or the named input(s) for this model as a dictionary
5456
sample_id: the sample id.
5557
blocksize_parameter: (optional) tile the input into blocks parametrized by
56-
blocksize according to any parametrized axis sizes defined in the model RDF
58+
blocksize according to any parametrized axis sizes defined in the model RDF.
59+
Note: For a predetermined, fixed block shape use `input_block_shape`
60+
input_block_shape: (optional) tile the input sample tensors into blocks.
61+
Note: To specify a parameterized block shape rather than an exact block shape,
62+
use `blocksize_parameter`.
5763
skip_preprocessing: flag to skip the model's preprocessing
5864
skip_postprocessing: flag to skip the model's postprocessing
5965
save_output_path: A path with `{member_id}` `{sample_id}` in it
@@ -83,19 +89,33 @@ def predict(
8389
pp.model_description, inputs=inputs, sample_id=sample_id
8490
)
8591

86-
if blocksize_parameter is None:
87-
output = pp.predict_sample_without_blocking(
92+
if input_block_shape is not None:
93+
if blocksize_parameter is not None:
94+
logger.warning(
95+
"ignoring blocksize_parameter={} in favor of input_block_shape={}",
96+
blocksize_parameter,
97+
input_block_shape,
98+
)
99+
100+
output = pp.predict_sample_with_fixed_blocking(
88101
sample,
102+
input_block_shape=input_block_shape,
89103
skip_preprocessing=skip_preprocessing,
90104
skip_postprocessing=skip_postprocessing,
91105
)
92-
else:
106+
elif blocksize_parameter is not None:
93107
output = pp.predict_sample_with_blocking(
94108
sample,
95109
skip_preprocessing=skip_preprocessing,
96110
skip_postprocessing=skip_postprocessing,
97111
ns=blocksize_parameter,
98112
)
113+
else:
114+
output = pp.predict_sample_without_blocking(
115+
sample,
116+
skip_preprocessing=skip_preprocessing,
117+
skip_postprocessing=skip_postprocessing,
118+
)
99119
if save_output_path:
100120
save_sample(save_output_path, output)
101121

0 commit comments

Comments
 (0)