41
41
from .tensor import Tensor
42
42
43
43
44
- def convert_axis_ids (
45
- axes : Union [ Sequence [ AxisId ], v0_4 .AxesInCZYX ] ,
44
+ def _convert_axis_ids (
45
+ axes : v0_4 .AxesInCZYX ,
46
46
mode : Literal ["per_sample" , "per_dataset" ],
47
47
) -> Tuple [AxisId , ...]:
48
48
if not isinstance (axes , str ):
49
49
return tuple (axes )
50
50
51
- axis_map = dict (b = AxisId ("batch" ), c = AxisId ("channel" ), i = AxisId ("index" ))
52
51
if mode == "per_sample" :
53
52
ret = []
54
53
elif mode == "per_dataset" :
55
- ret = [AxisId ("batch " )]
54
+ ret = [AxisId ("b " )]
56
55
else :
57
56
assert_never (mode )
58
57
59
- ret .extend ([axis_map . get ( a , AxisId (a ) ) for a in axes ])
58
+ ret .extend ([AxisId (a ) for a in axes ])
60
59
return tuple (ret )
61
60
62
61
@@ -375,7 +374,7 @@ def from_proc_descr(
375
374
member_id : MemberId ,
376
375
) -> Self :
377
376
kwargs = descr .kwargs
378
- axes = _get_axes (descr .kwargs )
377
+ _ , axes = _get_axes (descr .kwargs )
379
378
380
379
return cls (
381
380
input = member_id ,
@@ -395,18 +394,18 @@ def _get_axes(
395
394
v0_4 .ScaleMeanVarianceKwargs ,
396
395
v0_5 .ScaleMeanVarianceKwargs ,
397
396
]
398
- ) -> Union [ Tuple [AxisId , ...], None ]:
397
+ ) -> Tuple [ bool , Optional [ Tuple [AxisId , ...]] ]:
399
398
if kwargs .axes is None :
400
- axes = None
399
+ return True , None
401
400
elif isinstance (kwargs .axes , str ):
402
- axes = convert_axis_ids (kwargs .axes , kwargs ["mode" ])
401
+ axes = _convert_axis_ids (kwargs .axes , kwargs ["mode" ])
402
+ return AxisId ("b" ) in axes , axes
403
403
elif isinstance (kwargs .axes , collections .abc .Sequence ):
404
404
axes = tuple (kwargs .axes )
405
+ return AxisId ("batch" ) in axes , axes
405
406
else :
406
407
assert_never (kwargs .axes )
407
408
408
- return axes
409
-
410
409
411
410
@dataclass
412
411
class ScaleRange (_SimpleOperator ):
@@ -458,8 +457,8 @@ def from_proc_descr(
458
457
if kwargs .reference_tensor is None
459
458
else MemberId (str (kwargs .reference_tensor ))
460
459
)
461
- axes = _get_axes (descr .kwargs )
462
- if axes is None or AxisId ( "batch" ) in axes :
460
+ dataset_mode , axes = _get_axes (descr .kwargs )
461
+ if dataset_mode :
463
462
Percentile = DatasetPercentile
464
463
else :
465
464
Percentile = SampleQuantile
@@ -549,9 +548,9 @@ def from_proc_descr(
549
548
descr : Union [v0_4 .ZeroMeanUnitVarianceDescr , v0_5 .ZeroMeanUnitVarianceDescr ],
550
549
member_id : MemberId ,
551
550
):
552
- axes = _get_axes (descr .kwargs )
551
+ dataset_mode , axes = _get_axes (descr .kwargs )
553
552
554
- if axes is None or AxisId ( "batch" ) in axes :
553
+ if dataset_mode :
555
554
Mean = DatasetMean
556
555
Std = DatasetStd
557
556
else :
0 commit comments