Skip to content

Commit 5a9c1c7

Browse files
committed
do not convert axes ids for proc ops
1 parent bc98d65 commit 5a9c1c7

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

bioimageio/core/proc_ops.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,21 @@
4141
from .tensor import Tensor
4242

4343

44-
def convert_axis_ids(
45-
axes: Union[Sequence[AxisId], v0_4.AxesInCZYX],
44+
def _convert_axis_ids(
45+
axes: v0_4.AxesInCZYX,
4646
mode: Literal["per_sample", "per_dataset"],
4747
) -> Tuple[AxisId, ...]:
4848
if not isinstance(axes, str):
4949
return tuple(axes)
5050

51-
axis_map = dict(b=AxisId("batch"), c=AxisId("channel"), i=AxisId("index"))
5251
if mode == "per_sample":
5352
ret = []
5453
elif mode == "per_dataset":
55-
ret = [AxisId("batch")]
54+
ret = [AxisId("b")]
5655
else:
5756
assert_never(mode)
5857

59-
ret.extend([axis_map.get(a, AxisId(a)) for a in axes])
58+
ret.extend([AxisId(a) for a in axes])
6059
return tuple(ret)
6160

6261

@@ -375,7 +374,7 @@ def from_proc_descr(
375374
member_id: MemberId,
376375
) -> Self:
377376
kwargs = descr.kwargs
378-
axes = _get_axes(descr.kwargs)
377+
_, axes = _get_axes(descr.kwargs)
379378

380379
return cls(
381380
input=member_id,
@@ -395,18 +394,18 @@ def _get_axes(
395394
v0_4.ScaleMeanVarianceKwargs,
396395
v0_5.ScaleMeanVarianceKwargs,
397396
]
398-
) -> Union[Tuple[AxisId, ...], None]:
397+
) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]:
399398
if kwargs.axes is None:
400-
axes = None
399+
return True, None
401400
elif isinstance(kwargs.axes, str):
402-
axes = convert_axis_ids(kwargs.axes, kwargs["mode"])
401+
axes = _convert_axis_ids(kwargs.axes, kwargs["mode"])
402+
return AxisId("b") in axes, axes
403403
elif isinstance(kwargs.axes, collections.abc.Sequence):
404404
axes = tuple(kwargs.axes)
405+
return AxisId("batch") in axes, axes
405406
else:
406407
assert_never(kwargs.axes)
407408

408-
return axes
409-
410409

411410
@dataclass
412411
class ScaleRange(_SimpleOperator):
@@ -458,8 +457,8 @@ def from_proc_descr(
458457
if kwargs.reference_tensor is None
459458
else MemberId(str(kwargs.reference_tensor))
460459
)
461-
axes = _get_axes(descr.kwargs)
462-
if axes is None or AxisId("batch") in axes:
460+
dataset_mode, axes = _get_axes(descr.kwargs)
461+
if dataset_mode:
463462
Percentile = DatasetPercentile
464463
else:
465464
Percentile = SampleQuantile
@@ -549,9 +548,9 @@ def from_proc_descr(
549548
descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr],
550549
member_id: MemberId,
551550
):
552-
axes = _get_axes(descr.kwargs)
551+
dataset_mode, axes = _get_axes(descr.kwargs)
553552

554-
if axes is None or AxisId("batch") in axes:
553+
if dataset_mode:
555554
Mean = DatasetMean
556555
Std = DatasetStd
557556
else:

0 commit comments

Comments
 (0)