Commit bd81ff2

Merge pull request #241 from bioimage-io/dataset_stat
Dataset statistics
2 parents e2a0008 + ded2063 · commit bd81ff2

14 files changed: +1312 −344 lines

bioimageio/core/prediction.py (+16 −8)
```diff
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from itertools import product
 from pathlib import Path
-from typing import Dict, List, Optional, OrderedDict, Sequence, Tuple, Union
+from typing import Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union
 
 import imageio
 import numpy as np
```
```diff
@@ -150,7 +150,15 @@ def _apply_crop(data, crop):
     return data[crop]
 
 
-def _get_tiling(shape, tile_shape, halo, input_axes):
+class TileDef(NamedTuple):
+    outer: Dict[str, slice]
+    inner: Dict[str, slice]
+    local: Dict[str, slice]
+
+
+def get_tiling(
+    shape: Sequence[int], tile_shape: Dict[str, int], halo: Dict[str, int], input_axes: Sequence[str]
+) -> Iterator[TileDef]:
     assert len(shape) == len(input_axes)
 
     shape_ = [sh for sh, ax in zip(shape, input_axes) if ax in "xyz"]
```
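The previously anonymous `(outer_tile, inner_tile, local_tile)` triple becomes a `TileDef`, so call sites can address the three slice dictionaries by name. As an illustrative sketch (the concrete slice values are assumptions, not output checked against `get_tiling`), the first tile of a 512×512 image with `tile_shape={"y": 256, "x": 256}` and `halo={"y": 16, "x": 16}` might look roughly like:

```python
from bioimageio.core.prediction import TileDef

# Illustrative values only; the actual slices depend on get_tiling's border handling.
tile = TileDef(
    outer={"b": slice(None), "c": slice(None), "y": slice(0, 272), "x": slice(0, 272)},  # tile + halo: read from input
    inner={"b": slice(None), "c": slice(None), "y": slice(0, 256), "x": slice(0, 256)},  # region written to the output
    local={"b": slice(None), "c": slice(None), "y": slice(0, 256), "x": slice(0, 256)},  # inner, relative to outer
)
```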
```diff
@@ -189,15 +197,15 @@ def _get_tiling(shape, tile_shape, halo, input_axes):
         local_tile["b"] = slice(None)
         local_tile["c"] = slice(None)
 
-        yield outer_tile, inner_tile, local_tile
+        yield TileDef(outer_tile, inner_tile, local_tile)
 
 
 def _predict_with_tiling_impl(
     prediction_pipeline: PredictionPipeline,
-    inputs: List[xr.DataArray],
-    outputs: List[xr.DataArray],
-    tile_shapes: List[dict],
-    halos: List[dict],
+    inputs: Sequence[xr.DataArray],
+    outputs: Sequence[xr.DataArray],
+    tile_shapes: Sequence[Dict[str, int]],
+    halos: Sequence[Dict[str, int]],
     verbose: bool = False,
 ):
     if len(inputs) > 1:
```
```diff
@@ -214,7 +222,7 @@ def _predict_with_tiling_impl(
     tile_shape = tile_shapes[0]
     halo = halos[0]
 
-    tiles = _get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims)
+    tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims)
 
     assert all(isinstance(ax, str) for ax in input_.dims)
     input_axes: Tuple[str, ...] = input_.dims  # noqa
```
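With the leading underscore dropped and the signature annotated, `get_tiling` is usable outside `_predict_with_tiling_impl`. A minimal sketch of the tile-and-stitch pattern it supports, with an identity function standing in for the real prediction pipeline (the `predict` stand-in and the shapes are assumptions, not part of this commit):

```python
import numpy as np
import xarray as xr

from bioimageio.core.prediction import get_tiling


def predict(tile: xr.DataArray) -> xr.DataArray:
    """Identity stand-in for a real per-tile prediction."""
    return tile


image = xr.DataArray(np.random.rand(1, 1, 512, 512), dims=("b", "c", "y", "x"))
output = xr.zeros_like(image)

for tile in get_tiling(
    shape=image.shape,
    tile_shape={"y": 256, "x": 256},  # core tile size per spatial axis
    halo={"y": 16, "x": 16},  # extra context, cropped away after prediction
    input_axes=image.dims,
):
    tile_output = predict(image[tile.outer])  # predict on tile plus halo
    output[tile.inner] = tile_output[tile.local]  # write back only the valid core
```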
bioimageio/core/prediction_pipeline/_combined_processing.py
```diff
@@ -1,178 +1,68 @@
-import warnings
-from collections import defaultdict
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type
-
-import xarray as xr
+from typing import List, Optional, Sequence, Union
 
 from bioimageio.core.resource_io import nodes
-from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std
-from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName
-from ._processing import (
-    Binarize,
-    Clip,
-    EnsureDtype,
-    Processing,
-    ScaleLinear,
-    ScaleMeanVariance,
-    ScaleRange,
-    Sigmoid,
-    ZeroMeanUnitVariance,
-)
+from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing
+from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample
 
 try:
     from typing import Literal
 except ImportError:
     from typing_extensions import Literal  # type: ignore
 
-KNOWN_PREPROCESSING: Dict[PreprocessingName, Type[Processing]] = {
-    "binarize": Binarize,
-    "clip": Clip,
-    "scale_linear": ScaleLinear,
-    "scale_range": ScaleRange,
-    "sigmoid": Sigmoid,
-    "zero_mean_unit_variance": ZeroMeanUnitVariance,
-}
-
-KNOWN_POSTPROCESSING: Dict[PostprocessingName, Type[Processing]] = {
-    "binarize": Binarize,
-    "clip": Clip,
-    "scale_linear": ScaleLinear,
-    "scale_mean_variance": ScaleMeanVariance,
-    "scale_range": ScaleRange,
-    "sigmoid": Sigmoid,
-    "zero_mean_unit_variance": ZeroMeanUnitVariance,
-}
-
-
-Scope = Literal["sample", "dataset"]
-SAMPLE: Literal["sample"] = "sample"
-DATASET: Literal["dataset"] = "dataset"
-SCOPES: Set[Scope] = {SAMPLE, DATASET}
-
 
 class CombinedProcessing:
-    def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTensor]):
-        self._prep = [
-            KNOWN_PREPROCESSING[step.name](tensor_name=ipt.name, **step.kwargs)
-            for ipt in inputs
-            for step in ipt.preprocessing or []
-        ]
-        self._post = [
-            KNOWN_POSTPROCESSING.get(step.name)(tensor_name=out.name, **step.kwargs)
-            for out in outputs
-            for step in out.postprocessing or []
-        ]
+    def __init__(self, tensor_specs: Union[List[nodes.InputTensor], List[nodes.OutputTensor]]):
+        PRE: Literal["pre"] = "pre"
+        POST: Literal["post"] = "post"
+        proc_prefix: Optional[Literal["pre", "post"]] = None
+        self._procs = []
+        for t in tensor_specs:
+            if isinstance(t, nodes.InputTensor):
+                steps = t.preprocessing or []
+                if proc_prefix is not None and proc_prefix != PRE:
+                    raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}")
+
+                proc_prefix = PRE
+            elif isinstance(t, nodes.OutputTensor):
+                steps = t.postprocessing or []
+                if proc_prefix is not None and proc_prefix != POST:
+                    raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}")
+
+                proc_prefix = POST
+            else:
+                raise NotImplementedError(t)
+
+            for step in steps:
+                self._procs.append(KNOWN_PROCESSING[proc_prefix][step.name](tensor_name=t.name, **step.kwargs))
 
         # There is a difference between pre-and-postprocessing:
         # Pre-processing always returns float32, because its output is consumed by the model.
         # Post-processing, however, should return the dtype that is specified in the model spec.
         # todo: cast dtype for inputs before preprocessing? or check dtype?
-        for out in outputs:
-            self._post.append(EnsureDtype(tensor_name=out.name, dtype=out.data_type))
+        if proc_prefix == POST:
+            for t in tensor_specs:
+                self._procs.append(EnsureDtype(tensor_name=t.name, dtype=t.data_type))
 
-        self._req_input_stats = {s: self._collect_required_stats(self._prep, s) for s in SCOPES}
-        self._req_output_stats = {s: self._collect_required_stats(self._post, s) for s in SCOPES}
-        if self._req_output_stats[DATASET]:
+        self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs)
+        if proc_prefix == POST and self.required_measures[PER_DATASET]:
             raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented")
 
-        self._computed_dataset_stats: Optional[Dict[str, Dict[Measure, Any]]] = None
-
-        self.input_tensor_names = [ipt.name for ipt in inputs]
-        self.output_tensor_names = [out.name for out in outputs]
-        assert not any(name in self.output_tensor_names for name in self.input_tensor_names)
-        assert not any(name in self.input_tensor_names for name in self.output_tensor_names)
-
-    @property
-    def required_input_dataset_statistics(self) -> Dict[str, Set[Measure]]:
-        return self._req_input_stats[DATASET]
-
-    @property
-    def required_output_dataset_statistics(self) -> Dict[str, Set[Measure]]:
-        return self._req_output_stats[DATASET]
-
-    @property
-    def computed_dataset_statistics(self) -> Dict[str, Dict[Measure, Any]]:
-        return self._computed_dataset_stats
-
-    def apply_preprocessing(
-        self, *input_tensors: xr.DataArray
-    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
-        assert len(input_tensors) == len(self.input_tensor_names)
-        tensors = dict(zip(self.input_tensor_names, input_tensors))
-        sample_stats = self.compute_sample_statistics(tensors, self._req_input_stats[SAMPLE])
-        for proc in self._prep:
-            proc.set_computed_sample_statistics(sample_stats)
-            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])
-
-        return [tensors[tn] for tn in self.input_tensor_names], sample_stats
+        self.tensor_names = [t.name for t in tensor_specs]
 
-    def apply_postprocessing(
-        self, *output_tensors: xr.DataArray, input_sample_statistics: Dict[str, Dict[Measure, Any]]
-    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
-        assert len(output_tensors) == len(self.output_tensor_names)
-        tensors = dict(zip(self.output_tensor_names, output_tensors))
-        sample_stats = {
-            **input_sample_statistics,
-            **self.compute_sample_statistics(tensors, self._req_output_stats[SAMPLE]),
-        }
-        for proc in self._post:
-            proc.set_computed_sample_statistics(sample_stats)
-            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])
-
-        return [tensors[tn] for tn in self.output_tensor_names], sample_stats
-
-    def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
-        """
-        This method sets the externally computed dataset statistics.
-        Which statistics are expected is specified by the `required_dataset_statistics` property.
-        """
-        # always expect input tensor statistics
-        for tensor_name, req_measures in self.required_input_dataset_statistics:
-            comp_measures = computed.get(tensor_name, {})
-            for req_measure in req_measures:
-                if req_measure not in comp_measures:
-                    raise ValueError(f"Missing required measure {req_measure} for input tensor {tensor_name}")
-
-        # as output tensor statistics may initially not be available, we only warn about their absence
-        output_statistics_missing = False
-        for tensor_name, req_measures in self.required_output_dataset_statistics:
-            comp_measures = computed.get(tensor_name, {})
-            for req_measure in req_measures:
-                if req_measure not in comp_measures:
-                    output_statistics_missing = True
-                    warnings.warn(f"Missing required measure {req_measure} for output tensor {tensor_name}")
-
-        self._computed_dataset_stats = computed
-
-        # set dataset statistics for each processing step
-        for proc in self._prep:
-            proc.set_computed_dataset_statistics(self.computed_dataset_statistics)
-
-    @classmethod
-    def compute_sample_statistics(
-        cls, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
-    ) -> Dict[str, Dict[Measure, Any]]:
-        return {tname: cls._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}
+    def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None:
+        for proc in self._procs:
+            proc.set_computed_measures(computed_measures)
+            sample[proc.tensor_name] = proc.apply(sample[proc.tensor_name])
 
     @staticmethod
-    def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
-        ret = {}
-        for measure in measures:
-            ret[measure] = measure.compute(tensor)
-
-        return ret
-
-    @staticmethod
-    def _collect_required_stats(proc: Sequence[Processing], scope: Literal["sample", "dataset"]):
-        stats = defaultdict(set)
+    def _collect_required_measures(proc: Sequence[Processing]) -> RequiredMeasures:
+        ret: RequiredMeasures = {PER_SAMPLE: {}, PER_DATASET: {}}
         for p in proc:
-            if scope == SAMPLE:
-                req = p.get_required_sample_statistics()
-            elif scope == DATASET:
-                req = p.get_required_dataset_statistics()
-            else:
-                raise ValueError(scope)
-            for tn, ms in req.items():
-                stats[tn].update(ms)
+            for mode, ms_per_mode in p.get_required_measures().items():
+                for tn, ms_per_tn in ms_per_mode.items():
+                    if tn not in ret[mode]:
+                        ret[mode][tn] = set()
+
+                    ret[mode][tn].update(ms_per_tn)
 
-        return dict(stats)
+        return ret
```
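The four `required_*_statistics` properties collapse into a single `required_measures` mapping of the shape `{PER_SAMPLE: {tensor_name: {measure, ...}}, PER_DATASET: {tensor_name: {measure, ...}}}`, and `apply` now mutates a `Sample` (a tensor-name-to-array dict) in place instead of returning new lists. Below is a standalone sketch of the merge `_collect_required_measures` performs, with plain strings standing in for the real mode keys and `Measure` objects (all names here are placeholders, not the `_utils` definitions):

```python
from typing import Dict, List, Set

# Placeholder mode keys; the real ones come from prediction_pipeline._utils.
PER_SAMPLE, PER_DATASET = "per_sample", "per_dataset"
RequiredMeasures = Dict[str, Dict[str, Set[str]]]


def collect_required_measures(per_step: List[RequiredMeasures]) -> RequiredMeasures:
    """Union the measures each processing step requires, per mode and tensor."""
    ret: RequiredMeasures = {PER_SAMPLE: {}, PER_DATASET: {}}
    for req in per_step:
        for mode, ms_per_mode in req.items():
            for tn, ms in ms_per_mode.items():
                ret[mode].setdefault(tn, set()).update(ms)
    return ret


merged = collect_required_measures(
    [
        # e.g. a zero_mean_unit_variance step with mode="per_sample" needs mean and std
        {PER_SAMPLE: {"input0": {"mean", "std"}}, PER_DATASET: {}},
        # e.g. a scale_range step with mode="per_dataset" needs percentiles
        {PER_SAMPLE: {}, PER_DATASET: {"input0": {"percentile_0.1", "percentile_99.9"}}},
    ]
)
assert merged[PER_DATASET]["input0"] == {"percentile_0.1", "percentile_99.9"}
```

This nested form is what lets the pipeline distinguish statistics computable from the current sample from those that must be accumulated over a whole dataset, which is the point of this PR.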
