
Incorrect parameter dtype initialisation for flax transformer-engine modules #1451

Open
liamclarkza opened this issue Feb 3, 2025 · 0 comments

liamclarkza commented Feb 3, 2025

It appears that the dtype argument for layers like LayerNormDenseGeneral and MultiHeadAttention is being ignored.

The argument is documented as follows, so I would expect the parameters in the sample code below to be initialised to bfloat16, but they are not:

dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.
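
For comparison, plain Flax layers (e.g. flax.linen.Dense) take a separate param_dtype argument that does control the dtype of the initialised parameters; a minimal sketch of the behaviour I would expect (just to illustrate the expectation, the TE modules only expose dtype):

import flax.linen as nn
import jax
import jax.numpy as jnp

# Plain flax.linen.Dense: param_dtype controls the dtype of the initialised
# parameters, which is the behaviour I would expect from `dtype` above.
dense = nn.Dense(features=128, param_dtype=jnp.bfloat16)
variables = dense.init(jax.random.key(0), jnp.ones((8, 16, 128), dtype=jnp.bfloat16))
print(variables["params"]["kernel"].dtype)  # bfloat16
print(variables["params"]["bias"].dtype)    # bfloat16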

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax


class Model(nn.Module):
    embed_dim: int
    param_dtype: Any

    def setup(self):
        self.layer1 = te_flax.LayerNormDenseGeneral(
            self.embed_dim,
            return_layernorm_output=False,
            transpose_batch_sequence=False,
            dtype=self.param_dtype,
        )
        self.layer2 = te_flax.MultiHeadAttention(
            head_dim=self.embed_dim,
            num_attention_heads=8,
            dtype=self.param_dtype,
        )

    def __call__(self, x):
        x = self.layer1(x)[0]
        x = self.layer2(x, x)[0]
        return x


compute_dtype = jnp.bfloat16
param_dtype = jnp.bfloat16

x = jnp.ones((8, 16, 128), dtype=compute_dtype)
model = Model(
    embed_dim=128,
    param_dtype=param_dtype,
)

print(f"Test: {compute_dtype=}, {param_dtype=}")
print(model.tabulate(jax.random.key(0), x, console_kwargs={"width": 200}))

params = model.init(jax.random.key(0), x)

# Double check param dtypes
jax.tree.map_with_path(
    lambda k, v: print(f"{'/'.join(p.key for p in k):30s}\t expected dtype: {param_dtype.dtype}\t got dtype: {v.dtype}"),
    params,
)
y = jax.jit(model.apply)(params, x)
print(f"{y.dtype=}\t {y.shape=}")
Output:

Test: compute_dtype=<class 'jax.numpy.bfloat16'>, param_dtype=<class 'jax.numpy.bfloat16'>

                                                                                          Model Summary                                                                                          
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                                                     ┃ module                    ┃ inputs                   ┃ outputs                 ┃ params_axes         ┃ params                      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│                                                          │ Model                     │ bfloat16[8,16,128]       │ bfloat16[8,16,128]      │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer1                                                   │ LayerNormDenseGeneral     │ bfloat16[8,16,128]       │ - bfloat16[8,16,128]    │ kernel_axes:        │ kernel: float32[128,128]    │
│                                                          │                           │                          │ - None                  │   names: []         │ ln_bias: float32[128]       │
│                                                          │                           │                          │                         │ ln_bias_axes:       │ scale: float32[128]         │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - embed           │ 16,640 (66.6 KB)            │
│                                                          │                           │                          │                         │ scale_axes:         │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - embed           │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2                                                   │ MultiHeadAttention        │ - bfloat16[8,16,128]     │ - bfloat16[8,16,128]    │                     │                             │
│                                                          │                           │ - bfloat16[8,16,128]     │ - None                  │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/qkv                                               │ LayerNormDenseGeneral     │ bfloat16[8,16,128]       │ - bfloat16[8,16,3,1024] │ kernel_axes:        │ kernel: float32[128,3,1024] │
│                                                          │                           │                          │ - None                  │   names:            │ ln_bias: float32[128]       │
│                                                          │                           │                          │                         │   - nvte_w_fsdp     │ scale: float32[128]         │
│                                                          │                           │                          │                         │   - nvte_w_joined   │                             │
│                                                          │                           │                          │                         │   - nvte_w_tp       │ 393,472 (1.6 MB)            │
│                                                          │                           │                          │                         │ ln_bias_axes:       │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_no_shard │                             │
│                                                          │                           │                          │                         │ scale_axes:         │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_no_shard │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/DotProductAttention_0                             │ DotProductAttention       │ - bfloat16[8,16,3,8,128] │ bfloat16[8,16,8,128]    │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - deterministic: False   │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/DotProductAttention_0/_FusedDotProductAttention_0 │ _FusedDotProductAttention │ - bfloat16[8,16,3,8,128] │ bfloat16[8,16,8,128]    │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - deterministic: False   │                         │                     │                             │
│                                                          │                           │   dropout_rng: None      │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/out                                               │ DenseGeneral              │ bfloat16[8,16,1024]      │ bfloat16[8,16,128]      │ kernel_axes:        │ kernel: float32[1024,128]   │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_tp       │ 131,072 (524.3 KB)          │
│                                                          │                           │                          │                         │   - nvte_w_fsdp     │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│                                                          │                           │                          │                   Total │                     │ 541,184 (2.2 MB)            │
└──────────────────────────────────────────────────────────┴───────────────────────────┴──────────────────────────┴─────────────────────────┴─────────────────────┴─────────────────────────────┘
                                                                                                                                                                                                 
                                                                               Total Parameters: 541,184 (2.2 MB)                                                 


params/layer1/kernel          	 expected dtype: bfloat16	 got dtype: float32
params/layer1/ln_bias         	 expected dtype: bfloat16	 got dtype: float32
params/layer1/scale           	 expected dtype: bfloat16	 got dtype: float32
params/layer2/out/kernel      	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/kernel      	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/ln_bias     	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/scale       	 expected dtype: bfloat16	 got dtype: float32
y.dtype=dtype(bfloat16)	 y.shape=(8, 16, 128)
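
As a stopgap, the initialised parameters can be cast to the intended dtype after model.init; a minimal sketch (this is only a post-hoc cast, it does not address the ignored dtype argument and assumes no initialiser actually needs to run in bfloat16):

# Workaround sketch: cast every initialised parameter to the intended dtype.
# This is a post-hoc cast, not a fix for the ignored `dtype` argument.
params = jax.tree.map(lambda p: p.astype(param_dtype), params)

# Parameters are now bfloat16; the forward pass runs as before.
y = jax.jit(model.apply)(params, x)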

I have tested this with transformer-engine 1.14.0, which comes bundled with the nvcr.io/nvidia/jax:25.01-py3 docker image. I don't see a release for this version on GitHub yet, though. I have also tested with 1.12.0, which yielded the same results.

Printout from jax.print_environment_info()

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='experiment-7a880fc8-f456-head', release='6.8.0-49-generic', version='#49~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Nov  6 17:42:15 UTC 2', machine='x86_64')


$ nvidia-smi
Mon Feb  3 12:30:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
| N/A   41C    P0            120W /  700W |     550MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
| N/A   37C    P0            121W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   34C    P0            116W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
| N/A   38C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   41C    P0            126W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
| N/A   37C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
| N/A   38C    P0            113W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
| N/A   36C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    1   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    2   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    3   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    4   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    5   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    6   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    7   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
+-----------------------------------------------------------------------------------------+