Skip to content

Commit 1aad5f1

Browse files
committed
jax.numpy reductions: avoid upcast of f16 when dtype is specified by user
1 parent 30acd38 commit 1aad5f1

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

jax/_src/numpy/reductions.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
231231
initial: ArrayLike | None = None, where: ArrayLike | None = None,
232232
promote_integers: bool = True) -> Array:
233233
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
234-
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
234+
bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
235235
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
236236
initial=initial, where_=where, parallel_reduce=lax.psum,
237237
promote_integers=promote_integers)
@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
319319
initial: ArrayLike | None = None, where: ArrayLike | None = None,
320320
promote_integers: bool = True) -> Array:
321321
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
322-
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
322+
bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
323323
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
324324
initial=initial, where_=where, promote_integers=promote_integers)
325325

@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
865865
[6. ]], dtype=float32)
866866
"""
867867
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
868-
where=where)
868+
where=where, upcast_f16_for_computation=(dtype is None))
869869

870-
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
870+
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
871+
inline=True)
871872
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
872873
out: None = None, keepdims: bool = False, *,
873874
upcast_f16_for_computation: bool = True,

0 commit comments

Comments
 (0)