@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
231
231
initial : ArrayLike | None = None , where : ArrayLike | None = None ,
232
232
promote_integers : bool = True ) -> Array :
233
233
return _reduction (a , "sum" , lax .add , 0 , preproc = _cast_to_numeric ,
234
- bool_op = lax .bitwise_or , upcast_f16_for_computation = True ,
234
+ bool_op = lax .bitwise_or , upcast_f16_for_computation = ( dtype is None ) ,
235
235
axis = axis , dtype = dtype , out = out , keepdims = keepdims ,
236
236
initial = initial , where_ = where , parallel_reduce = lax .psum ,
237
237
promote_integers = promote_integers )
@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
319
319
initial : ArrayLike | None = None , where : ArrayLike | None = None ,
320
320
promote_integers : bool = True ) -> Array :
321
321
return _reduction (a , "prod" , lax .mul , 1 , preproc = _cast_to_numeric ,
322
- bool_op = lax .bitwise_and , upcast_f16_for_computation = True ,
322
+ bool_op = lax .bitwise_and , upcast_f16_for_computation = ( dtype is None ) ,
323
323
axis = axis , dtype = dtype , out = out , keepdims = keepdims ,
324
324
initial = initial , where_ = where , promote_integers = promote_integers )
325
325
@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
865
865
[6. ]], dtype=float32)
866
866
"""
867
867
return _mean (a , _ensure_optional_axes (axis ), dtype , out , keepdims ,
868
- where = where )
868
+ where = where , upcast_f16_for_computation = ( dtype is None ) )
869
869
870
- @partial (api .jit , static_argnames = ('axis' , 'dtype' , 'keepdims' ), inline = True )
870
+ @partial (api .jit , static_argnames = ('axis' , 'dtype' , 'keepdims' , 'upcast_f16_for_computation' ),
871
+ inline = True )
871
872
def _mean (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
872
873
out : None = None , keepdims : bool = False , * ,
873
874
upcast_f16_for_computation : bool = True ,
0 commit comments