@@ -231,7 +231,12 @@ def np_fun(x):
231
231
np .uint32 : 3e-7 , np .float32 : 1e-3 , np .complex64 : 1e-3 ,
232
232
np .float64 : 1e-5 , np .complex128 : 1e-5 }
233
233
tol = jtu .tolerance (dtype , tol_spec )
234
- tol = max (tol , jtu .tolerance (out_dtype , tol_spec )) if out_dtype else tol
234
+ if out_dtype in [np .float16 , dtypes .bfloat16 ]:
235
+ # For 16-bit out_type, NumPy will accumulate in float32, while JAX
236
+ # accumulates in 16-bit, so we need a larger tolerance.
237
+ tol = 1e-1
238
+ else :
239
+ tol = max (tol , jtu .tolerance (out_dtype , tol_spec )) if out_dtype else tol
235
240
self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker ,
236
241
check_dtypes = jnp .bfloat16 not in (dtype , out_dtype ),
237
242
tol = tol )
@@ -930,5 +935,50 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
930
935
self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker )
931
936
self ._CompileAndCheck (jnp_fun , args_maker )
932
937
938
# Checks which primitives `jnp.sum` / `jnp.prod` trace to when given a
# 16-bit floating-point input ('float16' or 'bfloat16'), both with and
# without an explicit `dtype` argument.
@jtu.sample_product(
    op=['sum', 'prod'],
    dtype=['float16', 'bfloat16'],
)
def testReducerF16Casts(self, op, dtype):
    rng = jtu.rand_default(self.rng())
    x = jnp.asarray(rng((10,), dtype))

    func = getattr(jnp, op)
    # The attribute name must not contain whitespace: for op='sum' this has
    # to resolve to `jax.lax.reduce_sum_p`, not "reduce_sum _p".
    reduce_p = getattr(jax.lax, f"reduce_{op}_p")
    conv_elem_p = jax.lax.convert_element_type_p

    # Without dtype specified, the reduction is sandwiched between two casts.
    jaxpr1 = jax.make_jaxpr(func)(x)
    self.assertEqual(
        [eqn.primitive for eqn in jaxpr1.eqns],
        [conv_elem_p, reduce_p, conv_elem_p])

    # With dtype specified, the reduction happens without a cast.
    jaxpr2 = jax.make_jaxpr(partial(func, dtype=dtype))(x)
    self.assertEqual([eqn.primitive for eqn in jaxpr2.eqns], [reduce_p])
959
+
960
# Checks which primitives `jnp.mean` traces to when given a 16-bit
# floating-point input ('float16' or 'bfloat16'), both with and without an
# explicit `dtype` argument.
@jtu.sample_product(
    dtype=['float16', 'bfloat16'],
)
def testMeanF16Casts(self, dtype):
    make_values = jtu.rand_default(self.rng())
    operand = jnp.asarray(make_values((10,), dtype))

    cast_p = jax.lax.convert_element_type_p
    sum_p = jax.lax.reduce_sum_p
    quotient_p = jax.lax.div_p

    def traced_primitives(fn):
        # Trace `fn` on the sample operand and list the primitive of each
        # equation in the resulting jaxpr, in order.
        traced = jax.make_jaxpr(fn)(operand)
        return [eqn.primitive for eqn in traced.eqns]

    # Default call: the sum and divide are wrapped by a cast on each side.
    default_prims = traced_primitives(jnp.mean)
    self.assertEqual(default_prims, [cast_p, sum_p, quotient_p, cast_p])

    # Explicit dtype: only the sum and the divide remain, with no casts.
    explicit_prims = traced_primitives(partial(jnp.mean, dtype=dtype))
    self.assertEqual(explicit_prims, [sum_p, quotient_p])
982
+
933
983
# Script entry point: hand the tests to absltest using the JAX test loader.
if __name__ == "__main__":
    jax_loader = jtu.JaxTestLoader()
    absltest.main(testLoader=jax_loader)
0 commit comments