Commit 4f1c67e

Merge pull request #26403 from jakevdp:bf16-mean
PiperOrigin-RevId: 726157721
2 parents: 5b69772 + b5e7b60

2 files changed (+56 -5 lines)


jax/_src/numpy/reductions.py (+5 -4)
@@ -231,7 +231,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                 initial: ArrayLike | None = None, where: ArrayLike | None = None,
                 promote_integers: bool = True) -> Array:
   return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_or, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, parallel_reduce=lax.psum,
                     promote_integers=promote_integers)
@@ -319,7 +319,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
                 initial: ArrayLike | None = None, where: ArrayLike | None = None,
                 promote_integers: bool = True) -> Array:
   return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
-                    bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
+                    bool_op=lax.bitwise_and, upcast_f16_for_computation=(dtype is None),
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                     initial=initial, where_=where, promote_integers=promote_integers)

@@ -865,9 +865,10 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
            [6. ]], dtype=float32)
   """
   return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
-               where=where)
+               where=where, upcast_f16_for_computation=(dtype is None))
 
-@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
+@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'upcast_f16_for_computation'),
+         inline=True)
 def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
           out: None = None, keepdims: bool = False, *,
           upcast_f16_for_computation: bool = True,
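
The net effect of these three changes: when the caller pins an explicit 16-bit dtype, the reduction runs directly in that dtype instead of being upcast to float32 and cast back. A minimal sketch of how to observe this, mirroring what the new tests below assert (the lambda wrapper here is just for illustration):

import jax
import jax.numpy as jnp

x = jnp.ones(10, dtype=jnp.bfloat16)

# dtype unspecified: the jaxpr is convert_element_type -> reduce_sum ->
# convert_element_type, i.e. the sum is still computed in float32.
print(jax.make_jaxpr(jnp.sum)(x))

# dtype pinned to bfloat16: the jaxpr contains only the reduce_sum,
# with no casts around it.
print(jax.make_jaxpr(lambda v: jnp.sum(v, dtype=jnp.bfloat16))(x))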

tests/lax_numpy_reducers_test.py (+51 -1)
@@ -231,7 +231,12 @@ def np_fun(x):
                 np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3,
                 np.float64: 1e-5, np.complex128: 1e-5}
     tol = jtu.tolerance(dtype, tol_spec)
-    tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
+    if out_dtype in [np.float16, dtypes.bfloat16]:
+      # For 16-bit out_dtype, NumPy will accumulate in float32, while JAX
+      # accumulates in 16-bit, so we need a larger tolerance.
+      tol = 1e-1
+    else:
+      tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                             check_dtypes=jnp.bfloat16 not in (dtype, out_dtype),
                             tol=tol)
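
The relaxed 1e-1 tolerance in the hunk above reflects a real precision tradeoff: bfloat16 carries only 8 mantissa bits (float16 carries 11), so accumulating many terms directly in 16-bit loses noticeably more precision than float32 accumulation. A rough sketch of the effect (the exact discrepancy depends on the input values):

import jax.numpy as jnp

x = jnp.linspace(0, 1, 1000, dtype=jnp.bfloat16)

# Explicit 16-bit dtype: JAX now accumulates directly in bfloat16.
s_16bit = jnp.sum(x, dtype=jnp.bfloat16)

# Default dtype: JAX still upcasts to float32 for the accumulation,
# then casts the result back to bfloat16.
s_default = jnp.sum(x)

# The two results may differ; the test above allows a tolerance of 1e-1.
print(s_16bit, s_default)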
@@ -930,5 +935,50 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+    op=['sum', 'prod'],
+    dtype=['float16', 'bfloat16'],
+  )
+  def testReducerF16Casts(self, op, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    func = getattr(jnp, op)
+    reduce_p = getattr(jax.lax, f"reduce_{op}_p")
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(func)(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr1.eqns],
+        [conv_elem_p, reduce_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x)
+    self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p])
+
+  @jtu.sample_product(
+    dtype=['float16', 'bfloat16'],
+  )
+  def testMeanF16Casts(self, dtype):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng((10,), dtype))
+
+    reduce_sum_p = jax.lax.reduce_sum_p
+    div_p = jax.lax.div_p
+    conv_elem_p = jax.lax.convert_element_type_p
+
+    # Without dtype specified, the reduction is sandwiched between two casts.
+    jaxpr1 = jax.make_jaxpr(jnp.mean)(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr1.eqns],
+        [conv_elem_p, reduce_sum_p, div_p, conv_elem_p])
+
+    # With dtype specified, the reduction happens without a cast.
+    jaxpr2 = jax.make_jaxpr(partial(jnp.mean, dtype=dtype))(x)
+    self.assertEqual(
+        [eqn.primitive for eqn in jaxpr2.eqns],
+        [reduce_sum_p, div_p])
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
