Commit 616b946

jax.numpy reductions: avoid upcast of f16 when dtype is specified by user
1 parent 30acd38 commit 616b946
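
For orientation, a small sketch of the user-visible effect described in the commit message (not part of the commit itself; it simply mirrors the jaxpr checks in the tests added below):

import jax
import jax.numpy as jnp

x = jnp.ones(10, dtype=jnp.float16)

# Without an explicit dtype, the float16 sum is still computed in float32:
# convert_element_type -> reduce_sum -> convert_element_type.
print(jax.make_jaxpr(jnp.sum)(x))

# With dtype specified by the user, the upcast is skipped and the
# reduction runs directly in float16 (a single reduce_sum).
print(jax.make_jaxpr(lambda v: jnp.sum(v, dtype=jnp.float16))(x))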

File tree

2 files changed, +50 -4 lines changed


jax/_src/numpy/reductions.py

+5 -4

@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                 initial: ArrayLike | None = None, where: ArrayLike | None = None,
                 promote_integers: bool = True) -> Array:
   return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.psum,
                     promote_integers=promote_integers)

@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
                  initial: ArrayLike | None = None, where: ArrayLike | None = None,
                  promote_integers: bool = True) -> Array:
   return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, promote_integers=promote_integers)

@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
          [6. ]], dtype=float32)
   """
   return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
-               where=where)
+               where=where, upcast_f16_for_computation=(dtype is None))
 
-@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
+@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
+         inline=True)
 def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False, *,
           upcast_f16_for_computation: bool = True,
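
In words: upcast_f16_for_computation previously forced float16/bfloat16 reductions to accumulate in float32 and cast the result back; after this change the callers enable it only when the user did not pass dtype. A hedged, standalone sketch of that behavior (an illustrative approximation, not JAX's internal _reduction helper; the name sum_with_optional_upcast is made up for this example):

import numpy as np
import jax.numpy as jnp

def sum_with_optional_upcast(a, dtype=None, upcast_f16_for_computation=True):
  # Illustrative approximation only; callers in this commit pass
  # upcast_f16_for_computation=(dtype is None).
  out_dtype = np.dtype(dtype) if dtype is not None else a.dtype
  comp_dtype = out_dtype
  if upcast_f16_for_computation and comp_dtype in (np.dtype('float16'), np.dtype(jnp.bfloat16)):
    comp_dtype = np.dtype('float32')  # accumulate in f32 for accuracy
  # Passing dtype= explicitly keeps the inner reduction in comp_dtype.
  return jnp.sum(a.astype(comp_dtype), dtype=comp_dtype).astype(out_dtype)

With dtype supplied by the user, the final astype is a no-op and the whole computation stays in 16-bit, which is what the jaxpr assertions in the tests below check.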

tests/lax_numpy_reducers_test.py

+45

@@ -930,5 +930,50 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+    op=['sum', 'prod'],
+    dtype=['float16', 'bfloat16'],
+  )
+  def testReducerF16Casts(self, op, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    func = getattr(jnp, op)
+    reduce_p = getattr(jax.lax, f"reduce_{op}_p")
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(func)(x)
+    self.assertEqual(
+      [eqn.primitive for eqn in jaxpr1.eqns],
+      [conv_elem_p, reduce_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x)
+    self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p])
+
+  @jtu.sample_product(
+    dtype=['float16', 'bfloat16'],
+  )
+  def testMeanF16Casts(self, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    reduce_sum_p = jax.lax.reduce_sum_p
+    div_p = jax.lax.div_p
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(jnp.mean)(x)
+    self.assertEqual(
+      [eqn.primitive for eqn in jaxpr1.eqns],
+      [conv_elem_p, reduce_sum_p, div_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(jnp.mean, dtype=dtype))(x)
+    self.assertEqual(
+      [eqn.primitive for eqn in jaxpr2.eqns],
+      [reduce_sum_p, div_p])
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
