
Feat: Implementation of the DeepSeek blockwise quantization for fp8 tensors #1763


Open
wants to merge 13 commits into main from feat/blockwise_fp8_quant_triton_gemm_ker

Conversation

Degnel

@Degnel Degnel commented Feb 22, 2025

This PR is the first step towards addressing issue #1594. It includes the following implementations:

  • fp8 triton gemm for blockwise quantisation
  • quant, dequant and linear utilities
  • time & precision benchmarks
  • basic tests

If the code is validated, it would be great to bench it on H100.
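For readers skimming the thread, here is a minimal pure-PyTorch sketch of the blockwise scheme being implemented (128x128 tiles for weights, per the DeepSeek-V3 recipe). The helper names and shapes below are illustrative only, not this PR's API:

```python
import torch

def blockwise_quant_fp8(w: torch.Tensor, block: int = 128, dtype=torch.float8_e4m3fn):
    # Quantize a 2D weight to fp8 with one scale per (block x block) tile.
    # Assumes both dimensions are multiples of `block`.
    M, N = w.shape
    fp8_max = torch.finfo(dtype).max
    tiles = w.reshape(M // block, block, N // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                      # one fp32 scale per tile
    w_q = (tiles / scale).to(dtype)
    return w_q.reshape(M, N), scale.reshape(M // block, N // block)

def blockwise_dequant_fp8(w_q: torch.Tensor, scale: torch.Tensor, block: int = 128):
    # Inverse: broadcast each tile's scale back over its block and multiply.
    M, N = w_q.shape
    tiles = w_q.to(torch.float32).reshape(M // block, block, N // block, block)
    return (tiles * scale[:, None, :, None]).reshape(M, N)

# Round-trip sanity check; activations would use 1 x block groups instead.
w = torch.randn(256, 512)
w_q, s = blockwise_quant_fp8(w)
max_err = (blockwise_dequant_fp8(w_q, s) - w).abs().max()
```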


pytorch-bot bot commented Feb 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1763

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0af2b9b with merge base 137b079:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) Feb 22, 2025
@Degnel Degnel mentioned this pull request Feb 22, 2025
@danielvegamyhre
Contributor

Thanks for your work on this! I'll take a closer look next week.

cc @vkuzo @drisspg

@Degnel
Author

Degnel commented Feb 25, 2025

Thanks for running the tests. I have two questions regarding the errors:

  • Where should I add Triton to allow the tests to run successfully without introducing unnecessary dependencies in dev-requirements.txt?
  • Does torchao provide any utility to check the available FP8 types for each GPU architecture?

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 27, 2025

Thanks for running the tests. I have two questions regarding the errors:

  • Where should I add Triton to allow the tests to run successfully without introducing unnecessary dependencies in dev-requirements.txt?

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

  • Does torchao provide any utility to check the available FP8 types for each GPU architecture?

We just use helpers which skip tests if GPU architecture is not at least SM 89:

def is_sm_at_least_89():

You can find examples in the float8 tests (example).
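For concreteness, the skip pattern in the float8 tests looks roughly like this (a sketch; it assumes the helper is importable from torchao.utils, and the test body is only illustrative):

```python
import pytest
import torch
from torchao.utils import is_sm_at_least_89

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_89(), reason="fp8 needs SM 8.9+ (Ada/Hopper)")
def test_fp8_cast_roundtrip():
    x = torch.randn(128, 128, device="cuda")
    x_fp8 = x.to(torch.float8_e4m3fn)
    assert x_fp8.dtype == torch.float8_e4m3fn
```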

@danielvegamyhre danielvegamyhre self-assigned this Feb 27, 2025
@danielvegamyhre danielvegamyhre self-requested a review February 27, 2025 17:45
@Degnel
Author

Degnel commented Feb 28, 2025

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

Indeed, they are. It looks like only the CPU runs are failing. I presume that bitsandbytes might not install triton when no GPU is available (I might be missing something there). Here is an instance of a failing log:

https://github.com/pytorch/ao/actions/runs/13484452669/job/37730985419?pr=1763#step:14:1276

We just use helpers which skip tests if GPU architecture is not at least SM 89:

def is_sm_at_least_89():

You can find examples in the float8 tests (example).

Thank you for the hint, I've locally updated the code accordingly 👍

W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

quantize_(lin, int8_dynamic_activation_int4_weight())
Contributor

why is int8_dynamic_activation_int4_weight being used here?

Author

Thanks for noticing it. I was aiming for a static W4A8 quantization and I overlooked that it was dynamic. I will try to address this within the week.

@danielvegamyhre
Contributor

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

Also @Degnel you should skip tests requiring triton if CUDA is not available.
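A sketch of one way to do that, in the spirit of the later "importing triton only if cuda available" commit (the structure here is an assumption, not the PR's exact code):

```python
import torch

# Only pull in triton when a GPU is present; CPU-only CI images may not ship it.
if torch.cuda.is_available():
    import triton
    import triton.language as tl

def blockwise_fp8_gemm(a_q, a_s, w_q, w_s):
    if not torch.cuda.is_available():
        raise RuntimeError("blockwise_fp8_gemm requires CUDA and triton")
    ...  # launch the triton kernel here
```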

@danielvegamyhre
Contributor

@Degnel thanks for your work on this, I ran the tests and it looks like your blockwise fp8 gemm test is failing due to quantization error.

@Degnel
Author

Degnel commented Mar 7, 2025

@Degnel thanks for your work on this, I ran the tests and it looks like your blockwise fp8 gemm test is failing due to quantization error.

Thanks for pointing that out! I had also noticed the issue, and I think I was just a bit too harsh with the threshold. I'll increase it to make it more reasonable. That said, I'll still double-check the calculations manually to ensure everything is mathematically correct.
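For reference, a hedged sketch of the kind of tolerance check involved; the threshold value here is an assumption, not the one the PR settles on:

```python
import torch

def relative_error(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Frobenius-norm relative error: robust to scale, unlike an absolute bound.
    return (out - ref).norm() / ref.norm().clamp(min=1e-12)

# Inside the gemm test (illustrative):
# ref = (A @ W.t()).float()
# out = blockwise_fp8_gemm(A_q, A_s, W_q, W_s).float()
# assert relative_error(out, ref) < 5e-2  # assumed fp8 blockwise tolerance
```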

@Degnel
Author

Degnel commented Mar 13, 2025

@danielvegamyhre I believe that everything should be alright except for the PR Label Check (I'm not sure if I have the required rights to edit this). The test-mps-ops (macos-m1-stable) failed, but I think that the merge will fix it as it seems to be a newly introduced test.

@danielvegamyhre danielvegamyhre added the topic: new feature label Mar 13, 2025
@Degnel
Author

Degnel commented Mar 14, 2025

The test-mps-ops (macos-m1-stable) failed once again. I've seen other recent PRs both succeeding and failing this test (due to the same missing package 'importlib_metadata'). I don't think this is related to the code I wrote, but I might be missing something. Please, let me know if you have any insights.

@drisspg
Contributor

drisspg commented Mar 14, 2025

The mps test is unrelated, re-running tests.

@Degnel
Author

Degnel commented Apr 21, 2025

It seems like the new PRs are not failing anymore due to the macOS tests. Maybe we should try to rerun it here :) @danielvegamyhre @drisspg

@drisspg
Contributor

drisspg commented Apr 24, 2025

Sorry, could you do one more rebase to kick off CI again?

Degnel added 7 commits April 25, 2025 14:08
- fp8 triton gemm
- quant, dequant and linear utils
- time & precision benchmarks
- basic tests
- removing triton dependency
- cleaning adaptive dtype
- fixing W4A8 quantization for cutlass kernel in precision benchmark
- importing triton only if cuda available
- setting a less harsh threshold for quant-dequant and for gemm kernel mm precision
- condition triton import in gemm
- linting
@Degnel Degnel force-pushed the feat/blockwise_fp8_quant_triton_gemm_ker branch from e8edea9 to e41457c April 25, 2025 12:11
@Degnel
Author

Degnel commented Apr 25, 2025

Sorry, could you do one more rebase to kick off CI again?

No problem, it should be ok

@Degnel
Author

Degnel commented Apr 26, 2025

Thank you @drisspg, I've done the linting.

@drisspg
Contributor

drisspg commented Apr 26, 2025

cc @danielvegamyhre, can you carry this across the finish line?

Contributor

@danielvegamyhre danielvegamyhre left a comment

Great progress @Degnel and sorry for the delay! Just did another pass and left a couple comments.

W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

qact = _int8_symm_per_token_reduced_range_quant_cutlass(A)
Contributor

Why are int4 and int8 quant being used here?

Note these no longer exist in torchao (_int4_symm_per_token_quant_cutlass and _int8_symm_per_token_reduced_range_quant_cutlass) so they should be removed/changed but I'm unclear why int4/int8 quant is being used in the first place, can you clarify?

Author

To be transparent, this code was written a few months ago, and I don't fully recall the rationale behind using those specific quantization modes. At the time, the goal was to establish a baseline for comparison across different quantization strategies, and int4/int8 quantization options seemed like a good starting point. Looking back, I agree it was probably a poor choice. I've removed the deprecated imports.

Degnel and others added 3 commits April 29, 2025 14:52
- raising explicit error when running benchmark without cuda
- merging quant, dequant and gemm code into one file
- removing deprecated int4/int8 comparison
…:Degnel/ao into feat/blockwise_fp8_quant_triton_gemm_ker
@danielvegamyhre
Contributor

@Degnel I get an import error when trying to run the benchmarks, can you take a look?

(torchtitan) [[email protected] ~/ao/benchmarks (feat/blockwise_fp8_quant_triton_gemm_ker)]$ python benchmark_blockwise_scaled_linear_triton.py 
Traceback (most recent call last):
  File "/home/danvm/ao/benchmarks/benchmark_blockwise_scaled_linear_triton.py", line 15, in <module>
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    ...<3 lines>...
    )
  File "/home/danvm/ao/torchao/prototype/blockwise_fp8/__init__.py", line 1, in <module>
    from .blockwise_fp8_gemm_triton import blockwise_fp8_gemm
ModuleNotFoundError: No module named 'torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton'

- fix import in __init__.py and in blockwise_linear.py
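A plausible sketch of the corrected `__init__.py`, inferred from the traceback above and the benchmark's import path; the full symbol list is elided in the traceback, so only the confirmed names are shown:

```python
# torchao/prototype/blockwise_fp8/__init__.py (sketch)
# The traceback shows the old import pointing at a module named
# blockwise_fp8_gemm_triton; the code actually lives in blockwise_quantization.
from .blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_weight_quant,
    # ...remaining helpers from the traceback's elided import list
)

__all__ = [
    "blockwise_fp8_gemm",
    "fp8_blockwise_weight_quant",
]
```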
@danielvegamyhre
Contributor

danielvegamyhre commented Apr 30, 2025

Interesting, I see solid speedup vs fp16 for small M, but it falls off as M increases, to the point where fp8 blockwise is actually 4x slower than fp16 for larger shapes, which is the opposite of what I would expect... this is on H100s, for context.

Maybe the block sizes in the autotuner are not optimal for shapes of this size? What do you think @Degnel @drisspg ? I suppose we can merge this as an initial start, but we should profile this and fix the perf for larger shapes.

|   m |     k |     n |   block_size | dtype               |   fp16_latency (ms) |   blockwise_latency (ms) |   blockwise_speedup |
|----:|------:|------:|-------------:|:--------------------|--------------------:|-------------------------:|--------------------:|
|   1 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.144 |                   52.256 |            1.57195  |
|   1 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              95.936 |                   61.184 |            1.56799  |
|   1 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             429.472 |                  233.712 |            1.83761  |
|   1 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             245.44  |                  131.296 |            1.86936  |
|   1 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.208 |                   52.704 |            1.55981  |
|   1 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.192 |                   61.056 |            1.57547  |
|   1 |  8192 | 57344 |          128 | torch.float8_e5m2   |             430.944 |                  234.08  |            1.84101  |
|   1 | 28672 |  8192 |          128 | torch.float8_e5m2   |             244.736 |                  131.104 |            1.86673  |
|   2 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.048 |                   53.92  |            1.52166  |
|   2 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.448 |                   62.016 |            1.55521  |
|   2 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             432.16  |                  233.504 |            1.85076  |
|   2 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.544 |                  135.808 |            1.68285  |
|   2 |  8192 |  8192 |          128 | torch.float8_e5m2   |              81.984 |                   54.048 |            1.51687  |
|   2 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.352 |                   61.952 |            1.55527  |
|   2 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.544 |                  235.744 |            1.8348   |
|   2 | 28672 |  8192 |          128 | torch.float8_e5m2   |             227.328 |                  136.512 |            1.66526  |
|   4 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.816 |                   54.336 |            1.52415  |
|   4 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.608 |                   62.048 |            1.55699  |
|   4 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             434.88  |                  234.624 |            1.85352  |
|   4 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             231.52  |                  137.12  |            1.68845  |
|   4 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.72  |                   54.208 |            1.52597  |
|   4 |  8192 | 10240 |          128 | torch.float8_e5m2   |              95.84  |                   62.368 |            1.53669  |
|   4 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.992 |                  238.464 |            1.81575  |
|   4 | 28672 |  8192 |          128 | torch.float8_e5m2   |             237.024 |                  136.032 |            1.74241  |
|   8 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.912 |                   56.512 |            1.46716  |
|   8 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.128 |                   62.64  |            1.53461  |
|   8 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             433.84  |                  243.68  |            1.78037  |
|   8 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.736 |                  143.072 |            1.59875  |
|   8 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.52  |                   56.736 |            1.47208  |
|   8 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.64  |                   63.072 |            1.53222  |
|   8 |  8192 | 57344 |          128 | torch.float8_e5m2   |             433.92  |                  243.296 |            1.78351  |
|   8 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.952 |                  140.128 |            1.66956  |
|  16 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.072 |                   58.816 |            1.4124   |
|  16 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              95.392 |                   66.944 |            1.42495  |
|  16 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             436.832 |                  244.352 |            1.78772  |
|  16 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.864 |                  152.64  |            1.49937  |
|  16 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.112 |                   60.416 |            1.35911  |
|  16 |  8192 | 10240 |          128 | torch.float8_e5m2   |              97.088 |                   66.464 |            1.46076  |
|  16 |  8192 | 57344 |          128 | torch.float8_e5m2   |             437.344 |                  247.728 |            1.76542  |
|  16 | 28672 |  8192 |          128 | torch.float8_e5m2   |             236.064 |                  155.072 |            1.52229  |
|  32 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.96  |                   62.72  |            1.35459  |
|  32 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.488 |                   69.248 |            1.43669  |
|  32 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             439.552 |                  281.328 |            1.56242  |
|  32 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.76  |                  168.064 |            1.3909   |
|  32 |  8192 |  8192 |          128 | torch.float8_e5m2   |              85.024 |                   62.672 |            1.35665  |
|  32 |  8192 | 10240 |          128 | torch.float8_e5m2   |              99.488 |                   69.088 |            1.44002  |
|  32 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.168 |                  275.616 |            1.59341  |
|  32 | 28672 |  8192 |          128 | torch.float8_e5m2   |             228.512 |                  167.888 |            1.3611   |
|  64 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.016 |                   76.768 |            1.06836  |
|  64 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.848 |                   92.032 |            1.05233  |
|  64 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             456.48  |                  515.248 |            0.885942 |
|  64 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             239.84  |                  244.864 |            0.979482 |
|  64 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.2   |                   77.28  |            1.0766   |
|  64 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.384 |                   91.968 |            1.04802  |
|  64 |  8192 | 57344 |          128 | torch.float8_e5m2   |             457.408 |                  506.368 |            0.903311 |
|  64 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.76  |                  243.904 |            0.95841  |
| 128 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.064 |                  141.696 |            0.59327  |
| 128 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.992 |                  170.688 |            0.591676 |
| 128 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             454.592 |                  991.936 |            0.458288 |
| 128 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.72  |                  483.04  |            0.485922 |
| 128 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.264 |                  141.616 |            0.587956 |
| 128 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.768 |                  171.488 |            0.58761  |
| 128 |  8192 | 57344 |          128 | torch.float8_e5m2   |             454.72  |                  982.976 |            0.462595 |
| 128 | 28672 |  8192 |          128 | torch.float8_e5m2   |             236.512 |                  479.872 |            0.492865 |
| 256 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              89.44  |                  285.504 |            0.313271 |
| 256 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             113.152 |                  342.144 |            0.330715 |
| 256 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             520.576 |                 1995.15  |            0.26092  |
| 256 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             244.48  |                  961.776 |            0.254196 |
| 256 |  8192 |  8192 |          128 | torch.float8_e5m2   |              89.056 |                  270.272 |            0.329505 |
| 256 |  8192 | 10240 |          128 | torch.float8_e5m2   |             111.424 |                  337.824 |            0.329829 |
| 256 |  8192 | 57344 |          128 | torch.float8_e5m2   |             527.168 |                 2036.4   |            0.258872 |
| 256 | 28672 |  8192 |          128 | torch.float8_e5m2   |             255.552 |                  961.888 |            0.265677 |
| 512 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             122.272 |                  556.512 |            0.219711 |
| 512 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             197.728 |                  674.688 |            0.293066 |
| 512 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             890.128 |                 4130.8   |            0.215486 |
| 512 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             412.416 |                 1883.71  |            0.218938 |
| 512 |  8192 |  8192 |          128 | torch.float8_e5m2   |             123.552 |                  542.624 |            0.227694 |
| 512 |  8192 | 10240 |          128 | torch.float8_e5m2   |             198.032 |                  678.016 |            0.292076 |
| 512 |  8192 | 57344 |          128 | torch.float8_e5m2   |             879.696 |                 3877.6   |            0.226866 |
| 512 | 28672 |  8192 |          128 | torch.float8_e5m2   |             412.256 |                 1936.1   |            0.212932 |

@Degnel
Author

Degnel commented May 1, 2025

Interesting, I see solid speedup vs fp16 for small M, but it falls off as M increases, to the point where fp8 blockwise is actually 4x slower than fp16 for larger shapes, which is the opposite of what I would expect... this is on H100s, for context.

Maybe the block sizes in the autotuner are not optimal for shapes of this size? What do you think @Degnel @drisspg ? I suppose we can merge this as an initial start, but we should profile this and fix the perf for larger shapes.

Thanks @danielvegamyhre. My first intuition was that it might be due to a hidden tensor copy, but the code looks fairly straightforward and I don't immediately see anything problematic there. That said, the original paper does mention some known inefficiencies: "partial results will be copied from Tensor Cores to CUDA cores". However, the sizes they benchmarked appear to be larger than the point where we start seeing the drop in efficiency. I'll try to dig into this over the next few days and see if I can pinpoint what's going on.
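One concrete thing to check while profiling: whether the autotune search space includes tiles large enough for M >= 128. A hedged sketch of widening it on the Triton side (config values are guesses, not tuned numbers, and the kernel signature is illustrative):

```python
import triton
import triton.language as tl

# BLOCK_SIZE_K stays pinned to the quantization group size (128) so each K-tile
# uses a single scale; M/N tiles, warps, and stages are the knobs worth sweeping.
configs = [
    triton.Config(
        {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": 128},
        num_warps=w, num_stages=s,
    )
    for bm in (16, 64, 128)
    for bn in (64, 128)
    for w in (4, 8)
    for s in (3, 4)
]

@triton.autotune(configs=configs, key=["M", "N", "K"])
@triton.jit
def blockwise_fp8_gemm_kernel(
    a_ptr, a_s_ptr, b_ptr, b_s_ptr, c_ptr,
    M, N, K,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    ...  # kernel body unchanged; only the autotune space differs
```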
