It appears that the `dtype` argument for layers like `LayerNormDenseGeneral` and `MultiHeadAttention` is being ignored.

The argument is documented as follows, so I would expect the parameters to be initialised to `bfloat16` in the sample code below, but this isn't the case:

> dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
I have tested this with transformer-engine 1.14.0, which comes bundled with the `nvcr.io/nvidia/jax:25.01-py3` Docker image; I don't see a release for this version on GitHub yet, though. I have also tested with 1.12.0, which yielded the same results.
Printout from `jax.print_environment_info()`: