Skip to content

Commit 1288593

Browse files
Merge pull request #226 from bioimage-io/axes_in_trfs
Fix mode fixed with axes
2 parents dc783e6 + d5c26e8 commit 1288593

File tree

4 files changed

+47
-18
lines changed

4 files changed

+47
-18
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,7 @@ def compute_sample_statistics(
158158
def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
159159
ret = {}
160160
for measure in measures:
161-
if isinstance(measure, Mean):
162-
v = tensor.mean(dim=measure.axes)
163-
elif isinstance(measure, Std):
164-
v = tensor.std(dim=measure.axes)
165-
elif isinstance(measure, Percentile):
166-
v = tensor.quantile(measure.n / 100.0, dim=measure.axes)
167-
else:
168-
raise NotImplementedError(measure)
169-
170-
ret[measure] = v
161+
ret[measure] = measure.compute(tensor)
171162

172163
return ret
173164

bioimageio/core/prediction_pipeline/_processing.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
from typing_extensions import Literal, get_args # type: ignore
1313

1414

15+
def _get_fixed(
16+
fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: Optional[Sequence[str]]
17+
) -> Union[float, xr.DataArray]:
18+
if axes is None:
19+
return fixed
20+
21+
fixed_shape = tuple(s for d, s in tensor.sizes.items() if d not in axes)
22+
fixed_dims = tuple(d for d in tensor.dims if d not in axes)
23+
fixed = np.array(fixed).reshape(fixed_shape)
24+
return xr.DataArray(fixed, dims=fixed_dims)
25+
26+
1527
@dataclass
1628
class Processing:
1729
"""base class for all Pre- and Postprocessing transformations"""
@@ -226,8 +238,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
226238
@dataclass
227239
class ZeroMeanUnitVariance(Processing):
228240
mode: Literal["fixed", "per_sample", "per_dataset"] = "per_sample"
229-
mean: Optional[float] = None
230-
std: Optional[float] = None
241+
mean: Optional[Union[float, Sequence[float]]] = None
242+
std: Optional[Union[float, Sequence[float]]] = None
231243
axes: Optional[Sequence[str]] = None
232244
eps: float = 1.0e-6
233245

@@ -247,12 +259,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
247259
axes = None if self.axes is None else tuple(self.axes)
248260
if self.mode == "fixed":
249261
assert self.mean is not None and self.std is not None
250-
mean, std = self.mean, self.std
262+
mean = _get_fixed(self.mean, tensor, axes)
263+
std = _get_fixed(self.std, tensor, axes)
251264
elif self.mode == "per_sample":
252-
if axes:
253-
mean, std = tensor.mean(axes), tensor.std(axes)
254-
else:
255-
mean, std = tensor.mean(), tensor.std()
265+
mean = Mean(axes).compute(tensor)
266+
std = Std(axes).compute(tensor)
256267
elif self.mode == "per_dataset":
257268
mean = self.get_computed_dataset_statistics(self.tensor_name, Mean(axes))
258269
std = self.get_computed_dataset_statistics(self.tensor_name, Std(axes))

bioimageio/core/statistical_measures.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
from dataclasses import dataclass
22
from typing import Optional, Tuple
33

4+
import xarray as xr
5+
46

57
@dataclass(frozen=True)
68
class Measure:
7-
pass
9+
def compute(self, tensor: xr.DataArray):
10+
raise NotImplementedError(self.__class__.__name__)
811

912

1013
@dataclass(frozen=True)
1114
class Mean(Measure):
1215
axes: Optional[Tuple[str]] = None
1316

17+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
18+
return tensor.mean(dim=self.axes)
19+
1420

1521
@dataclass(frozen=True)
1622
class Std(Measure):
1723
axes: Optional[Tuple[str]] = None
1824

25+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
26+
return tensor.std(dim=self.axes)
27+
1928

2029
@dataclass(frozen=True)
2130
class Percentile(Measure):
@@ -25,3 +34,6 @@ class Percentile(Measure):
2534
def __post_init__(self):
2635
assert self.n >= 0
2736
assert self.n <= 100
37+
38+
def compute(self, tensor: xr.DataArray) -> xr.DataArray:
39+
return tensor.quantile(self.n / 100.0, dim=self.axes)

tests/prediction_pipeline/test_preprocessing.py

+15
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ def test_zero_mean_unit_variance_preprocessing():
4343
xr.testing.assert_allclose(expected, result)
4444

4545

46+
def test_zero_mean_unit_variance_preprocessing_fixed():
47+
from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance
48+
49+
preprocessing = ZeroMeanUnitVariance(
50+
"data_name", mode="fixed", axes=["y"], mean=[1, 4, 7], std=[0.81650, 0.81650, 0.81650]
51+
)
52+
data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y"))
53+
expected = xr.DataArray(
54+
np.array([[-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743]])[None, None],
55+
dims=("b", "c", "x", "y"),
56+
)
57+
result = preprocessing(data)
58+
xr.testing.assert_allclose(expected, result)
59+
60+
4661
def test_zero_mean_unit_across_axes():
4762
from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance
4863

0 commit comments

Comments
 (0)