
Conversation

@kliuae-amd

Motivation

This PR fixes the incorrectly configured matA/matB scales of hipb_mm in B-preshuffled, row-wise FP8 mode.
It uses the HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F enum, introduced in hipblaslt >= 1.0.0, to set the scale mode; compilation is therefore guarded on the hipblaslt version, and a runtime error is raised for incompatible hipblaslt versions.
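For context, the scale-mode wiring looks roughly like the C++ sketch below. This is not the PR's diff: the attribute names HIPBLASLT_MATMUL_DESC_A_SCALE_MODE / HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, the HIPBLASLT_VERSION_MAJOR macro, and the helper name set_rowwise_fp8_scale_modes are assumptions for illustration; only HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F is named by the change itself.

```cpp
// Minimal sketch of per-row (outer-vector) FP32 scale-mode setup for
// row-wise FP8. Assumed, not taken from the PR diff: the *_SCALE_MODE
// attribute names and the HIPBLASLT_VERSION_MAJOR macro.
#include <hipblaslt/hipblaslt.h>
#include <stdexcept>

void set_rowwise_fp8_scale_modes(hipblasLtMatmulDesc_t desc)
{
#if defined(HIPBLASLT_VERSION_MAJOR) && HIPBLASLT_VERSION_MAJOR >= 1
    // Request per-row FP32 scales for A and B instead of the default
    // single tensor-wide scalar.
    const auto mode = HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;
    hipblasLtMatmulDescSetAttribute(
        desc, HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, &mode, sizeof(mode));
    hipblasLtMatmulDescSetAttribute(
        desc, HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, &mode, sizeof(mode));
#else
    // Older hipblaslt has no per-row scale mode; fail loudly rather than
    // silently produce wrongly scaled output.
    throw std::runtime_error(
        "row-wise FP8 hipb_mm requires hipblaslt >= 1.0.0");
#endif
}
```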

Technical Details

Test Plan

Unit test in op_tests/test_gemm_a8w8.py on MI300X
python op_tests/test_gemm_a8w8.py -d bf16 -q fp8

Test Result

Before the fix

[aiter] summary:                                                                                                                                  
             dtype      m     n     k             quantDtype    ......      hipmm bpreshuffle us  hipmm bpreshuffle err                          
0   torch.bfloat16      1  1280  8192  torch.float8_e4m3fnuz    ......                  9.076333               0.936719                          
1   torch.bfloat16     32  1280  8192  torch.float8_e4m3fnuz    ......                  9.860040               0.939575                          
2   torch.bfloat16     64  1280  8192  torch.float8_e4m3fnuz    ......                 10.005537               0.911316                          
3   torch.bfloat16    128  1280  8192  torch.float8_e4m3fnuz    ......                 13.438359               0.904913                          
4   torch.bfloat16    192  1280  8192  torch.float8_e4m3fnuz    ......                 15.519562               0.967102                          
5   torch.bfloat16    256  1280  8192  torch.float8_e4m3fnuz    ......                 17.775519               0.907144                          
6   torch.bfloat16    320  1280  8192  torch.float8_e4m3fnuz    ......                 19.356704               0.969473                          
7   torch.bfloat16    512  1280  8192  torch.float8_e4m3fnuz    ......                 24.966711               0.980595                          
8   torch.bfloat16   1024  1280  8192  torch.float8_e4m3fnuz    ......                 36.096970               0.964364                          
9   torch.bfloat16   2048  1280  8192  torch.float8_e4m3fnuz    ......                 57.476949               0.936559                          
10  torch.bfloat16   4096  1280  8192  torch.float8_e4m3fnuz    ......                 96.516929               0.954089                          
11  torch.bfloat16   8192  1280  8192  torch.float8_e4m3fnuz    ......                173.112388               0.914561                          
12  torch.bfloat16  16384  1280  8192  torch.float8_e4m3fnuz    ......                334.552804               0.920774                          
13  torch.bfloat16      1  8192  1024  torch.float8_e4m3fnuz    ......                  7.357798               0.875854
14  torch.bfloat16     32  8192  1024  torch.float8_e4m3fnuz    ......                  6.527495               0.910431                          
15  torch.bfloat16     64  8192  1024  torch.float8_e4m3fnuz    ......                  7.436687               0.978971                          
16  torch.bfloat16    128  8192  1024  torch.float8_e4m3fnuz    ......                  8.733949               0.949270                          
17  torch.bfloat16    192  8192  1024  torch.float8_e4m3fnuz    ......                  9.874323               0.909483                          
18  torch.bfloat16    256  8192  1024  torch.float8_e4m3fnuz    ......                 11.215212               0.991047                          
19  torch.bfloat16    320  8192  1024  torch.float8_e4m3fnuz    ......                 11.850485               0.917302                          
20  torch.bfloat16    512  8192  1024  torch.float8_e4m3fnuz    ......                 16.191485               0.927621
21  torch.bfloat16   1024  8192  1024  torch.float8_e4m3fnuz    ......                 26.278051               0.953078                          
22  torch.bfloat16   2048  8192  1024  torch.float8_e4m3fnuz    ......                 47.027571               0.912666
23  torch.bfloat16   4096  8192  1024  torch.float8_e4m3fnuz    ......                 92.053571               0.915039
24  torch.bfloat16   8192  8192  1024  torch.float8_e4m3fnuz    ......                160.548281               0.913475
25  torch.bfloat16  16384  8192  1024  torch.float8_e4m3fnuz    ......                315.413600               0.918916
26  torch.bfloat16     16  7424  8192  torch.float8_e4m3fnuz    ......                 18.439821               0.974197
27  torch.bfloat16     32  7424  8192  torch.float8_e4m3fnuz    ......                 20.613468               0.957360
28  torch.bfloat16     48  7424  8192  torch.float8_e4m3fnuz    ......                 26.358021               0.903660
29  torch.bfloat16     64  7424  8192  torch.float8_e4m3fnuz    ......                 26.915433               0.939001
30  torch.bfloat16   4096  7424  8192  torch.float8_e4m3fnuz    ......                475.293602               0.911551
31  torch.bfloat16   5120  7424  8192  torch.float8_e4m3fnuz    ......                539.614287               0.918253
32  torch.bfloat16   8192  7424  8192  torch.float8_e4m3fnuz    ......                901.106811               0.930693
cu_count=304

After the fix

[aiter] summary:                                                                                                                                              
             dtype      m     n     k             quantDtype    ......      hipmm bpreshuffle us  hipmm bpreshuffle err
0   torch.bfloat16      1  1280  8192  torch.float8_e4m3fnuz    ......                  9.293663               0.000000                          
1   torch.bfloat16     32  1280  8192  torch.float8_e4m3fnuz    ......                 10.013390               0.000000
2   torch.bfloat16     64  1280  8192  torch.float8_e4m3fnuz    ......                 10.231184               0.000000                          
3   torch.bfloat16    128  1280  8192  torch.float8_e4m3fnuz    ......                 13.525584               0.000012
4   torch.bfloat16    192  1280  8192  torch.float8_e4m3fnuz    ......                 15.700766               0.000012                          
5   torch.bfloat16    256  1280  8192  torch.float8_e4m3fnuz    ......                 18.110375               0.000015
6   torch.bfloat16    320  1280  8192  torch.float8_e4m3fnuz    ......                 19.618397               0.000029                          
7   torch.bfloat16    512  1280  8192  torch.float8_e4m3fnuz    ......                 25.294342               0.000018
8   torch.bfloat16   1024  1280  8192  torch.float8_e4m3fnuz    ......                 36.256869               0.000036                          
9   torch.bfloat16   2048  1280  8192  torch.float8_e4m3fnuz    ......                 57.973612               0.000029
10  torch.bfloat16   4096  1280  8192  torch.float8_e4m3fnuz    ......                 97.578092               0.000036                          
11  torch.bfloat16   8192  1280  8192  torch.float8_e4m3fnuz    ......                171.985725               0.000041
12  torch.bfloat16  16384  1280  8192  torch.float8_e4m3fnuz    ......                335.293327               0.000043                          
13  torch.bfloat16      1  8192  1024  torch.float8_e4m3fnuz    ......                  7.752010               0.000000
14  torch.bfloat16     32  8192  1024  torch.float8_e4m3fnuz    ......                  6.866939               0.000004                          
15  torch.bfloat16     64  8192  1024  torch.float8_e4m3fnuz    ......                  7.745848               0.000015
16  torch.bfloat16    128  8192  1024  torch.float8_e4m3fnuz    ......                  9.052798               0.000017                          
17  torch.bfloat16    192  8192  1024  torch.float8_e4m3fnuz    ......                 10.212596               0.000014
18  torch.bfloat16    256  8192  1024  torch.float8_e4m3fnuz    ......                 11.509394               0.000016                          
19  torch.bfloat16    320  8192  1024  torch.float8_e4m3fnuz    ......                 12.343394               0.000010
20  torch.bfloat16    512  8192  1024  torch.float8_e4m3fnuz    ......                 16.790705               0.000015                          
21  torch.bfloat16   1024  8192  1024  torch.float8_e4m3fnuz    ......                 26.965495               0.000016
22  torch.bfloat16   2048  8192  1024  torch.float8_e4m3fnuz    ......                 51.254551               0.000014
23  torch.bfloat16   4096  8192  1024  torch.float8_e4m3fnuz    ......                 89.330347               0.000015
24  torch.bfloat16   8192  8192  1024  torch.float8_e4m3fnuz    ......                175.018290               0.000015
25  torch.bfloat16  16384  8192  1024  torch.float8_e4m3fnuz    ......                320.445516               0.000014
26  torch.bfloat16     16  7424  8192  torch.float8_e4m3fnuz    ......                 18.641989               0.000008
27  torch.bfloat16     32  7424  8192  torch.float8_e4m3fnuz    ......                 20.872163               0.000013
28  torch.bfloat16     48  7424  8192  torch.float8_e4m3fnuz    ......                 26.644844               0.000031
29  torch.bfloat16     64  7424  8192  torch.float8_e4m3fnuz    ......                 26.901010               0.000015
30  torch.bfloat16   4096  7424  8192  torch.float8_e4m3fnuz    ......                476.617787               0.000041
31  torch.bfloat16   5120  7424  8192  torch.float8_e4m3fnuz    ......                542.135347               0.000041
32  torch.bfloat16   8192  7424  8192  torch.float8_e4m3fnuz    ......                914.271957               0.000040
cu_count=304

Submission Checklist

@valarLip
Collaborator

nice job, could you please fix the lint error
