@@ -71,7 +71,7 @@ def create(cls, axis: AxisLike) -> Axis:
71
71
72
72
@dataclass
73
73
class AxisInfo (Axis ):
74
- maybe_singleton : bool
74
+ maybe_singleton : bool # TODO: replace 'maybe_singleton' with size min/max for better axis guessing
75
75
76
76
@classmethod
77
77
def create (cls , axis : AxisLike , maybe_singleton : Optional [bool ] = None ) -> AxisInfo :
@@ -81,17 +81,17 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisI
81
81
axis_base = super ().create (axis )
82
82
if maybe_singleton is None :
83
83
if isinstance (axis , Axis ):
84
- maybe_singleton = False
84
+ maybe_singleton = axis . type in ( "batch" , "channel" , "index" )
85
85
elif isinstance (axis , str ):
86
- maybe_singleton = axis == "b"
86
+ maybe_singleton = axis in ( "b" , "c" , "i" )
87
87
else :
88
88
if axis .size is None :
89
89
maybe_singleton = True
90
90
elif isinstance (axis .size , int ):
91
91
maybe_singleton = axis .size == 1
92
92
elif isinstance (axis .size , v0_5 .SizeReference ):
93
93
maybe_singleton = (
94
- False # TODO: check if singleton is ok for a `SizeReference`
94
+ True # TODO: check if singleton is ok for a `SizeReference`
95
95
)
96
96
elif isinstance (
97
97
axis .size , (v0_5 .ParameterizedSize , v0_5 .DataDependentSize )
0 commit comments