@@ -99,18 +99,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
99
99
if not all (params .type .broadcastable ):
100
100
return None
101
101
102
- # Check whether axis covers all dimensions
103
- axis = set (node .op .axis )
104
- base_var_dims = set (range (base_var .ndim ))
105
- if axis != base_var_dims :
106
- return None
102
+ if node .op .axis is None :
103
+ axis = tuple (range (base_var .ndim ))
104
+ else :
105
+ # Check whether axis covers all dimensions
106
+ axis = tuple (sorted (node .op .axis ))
107
+ if axis != tuple (range (base_var .ndim )):
108
+ return None
107
109
108
110
# distinguish measurable discrete and continuous (because logprob is different)
109
111
measurable_max : Max
110
112
if base_var .type .dtype .startswith ("int" ):
111
- measurable_max = MeasurableMaxDiscrete (list ( axis ) )
113
+ measurable_max = MeasurableMaxDiscrete (axis )
112
114
else :
113
- measurable_max = MeasurableMax (list ( axis ) )
115
+ measurable_max = MeasurableMax (axis )
114
116
115
117
max_rv_node = measurable_max .make_node (base_var )
116
118
max_rv = max_rv_node .outputs
@@ -206,21 +208,23 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVa
206
208
if not all (params .type .broadcastable ):
207
209
return None
208
210
209
- # Check whether axis is supported or not
210
- axis = set (node .op .axis )
211
- base_var_dims = set (range (base_var .ndim ))
212
- if axis != base_var_dims :
213
- return None
211
+ if node .op .axis is None :
212
+ axis = tuple (range (base_var .ndim ))
213
+ else :
214
+ # Check whether axis is supported or not
215
+ axis = tuple (sorted (node .op .axis ))
216
+ if axis != tuple (range (base_var .ndim )):
217
+ return None
214
218
215
219
if not rv_map_feature .request_measurable ([base_rv ]):
216
220
return None
217
221
218
222
# distinguish measurable discrete and continuous (because logprob is different)
219
223
measurable_min : Max
220
224
if base_rv .type .dtype .startswith ("int" ):
221
- measurable_min = MeasurableDiscreteMaxNeg (list ( axis ) )
225
+ measurable_min = MeasurableDiscreteMaxNeg (axis )
222
226
else :
223
- measurable_min = MeasurableMaxNeg (list ( axis ) )
227
+ measurable_min = MeasurableMaxNeg (axis )
224
228
225
229
return measurable_min .make_node (base_rv ).outputs
226
230
0 commit comments