
Feat: Implementation of the DeepSeek blockwise quantization for fp8 tensors #1763


Conversation

the-tuning-machine
Contributor

This PR is the first step towards addressing issue #1594. It includes the following implementations:

  • fp8 triton gemm for blockwise quantization
  • quant, dequant and linear utilities
  • time & precision benchmarks
  • basic tests

If the code is validated, it would be great to benchmark it on an H100.
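
For orientation, here is a hedged sketch of how the utilities listed above fit together, based on the calls that appear in the review snippets further down (`fp8_blockwise_weight_quant`, `blockwise_fp8_gemm`); the module path and the block-layout comment are assumptions on my part.

```python
import torch

# Module path as it ends up later in this PR; treat it as an assumption if you
# are looking at an earlier revision of the branch.
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_weight_quant,
)

block_size, dtype = 128, torch.float8_e4m3fn
W = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)  # (N, K) weight

# Blockwise-quantize the weight to fp8, keeping one scale per weight block
# (128x128 in the DeepSeek recipe).
W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)

# A_q / A_s would come from the corresponding activation quant utility (elided
# here); the triton gemm then consumes the fp8 tensors plus their scales:
# out = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)
```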


pytorch-bot bot commented Feb 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1763

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 10c9436 with merge base c9b9adc:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 22, 2025
@danielvegamyhre
Contributor

Thanks for your work on this! I'll take a closer look next week.

cc @vkuzo @drisspg

@the-tuning-machine
Contributor Author

Thanks for running the tests. I have two questions regarding the errors:

  • Where should I add Triton to allow the tests to run successfully without introducing unnecessary dependencies in dev-requirements.txt?
  • Does torchao provide any utility to check the available FP8 types for each GPU architecture?

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 27, 2025

Thanks for running the tests. I have two questions regarding the errors:

  • Where should I add Triton to allow the tests to run successfully without introducing unnecessary dependencies in dev-requirements.txt?

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

  • Does torchao provide any utility to check the available FP8 types for each GPU architecture?

We just use helpers which skip tests if GPU architecture is not at least SM 89:

def is_sm_at_least_89():

You can find examples in the float8 tests (example).
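
For reference, a minimal sketch of that skip pattern (the exact import location of the helper is an assumption here):

```python
import pytest
import torch

from torchao.utils import is_sm_at_least_89  # helper location assumed


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not is_sm_at_least_89(), reason="Requires SM 8.9+ for fp8 support")
def test_blockwise_fp8_gemm_precision():
    ...  # fp8 kernels only run on SM 8.9+ GPUs (e.g. Ada/L4, H100)
```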

@danielvegamyhre danielvegamyhre self-assigned this Feb 27, 2025
@danielvegamyhre danielvegamyhre self-requested a review February 27, 2025 17:45
@the-tuning-machine
Contributor Author

the-tuning-machine commented Feb 28, 2025

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

Indeed, they are. It looks like only the CPU runs are failing. I presume that bitsandbytes might not install triton when no GPU is available (I might be missing something there). Here is an instance of a failing log:

https://github.com/pytorch/ao/actions/runs/13484452669/job/37730985419?pr=1763#step:14:1276

We just use helpers which skip tests if GPU architecture is not at least SM 89:

def is_sm_at_least_89():

You can find examples in the float8 tests (example).

Thank you for the hint, I've locally updated the code accordingly 👍

W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

quantize_(lin, int8_dynamic_activation_int4_weight())
Contributor

Why is int8_dynamic_activation_int4_weight being used here?

Contributor Author

Thanks for noticing it. I was aiming for a static W4A8 quantization and overlooked that it was dynamic. I will try to address this within the week.

@danielvegamyhre
Contributor

Can you clarify what you mean? Are tests failing in CI due to a missing triton installation? That shouldn't be happening, please share the link/logs if so.

Also @Degnel you should skip tests requiring triton if CUDA is not available.
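
For example, a minimal pattern along these lines (a sketch; the module path matches the one used later in the thread, the rest is an assumption):

```python
import importlib.util

import pytest
import torch

# Only import triton-backed code when both triton and a CUDA device are present.
triton_available = importlib.util.find_spec("triton") is not None

if triton_available and torch.cuda.is_available():
    from torchao.prototype.blockwise_fp8.blockwise_quantization import blockwise_fp8_gemm


@pytest.mark.skipif(
    not (triton_available and torch.cuda.is_available()),
    reason="Requires triton and a CUDA device",
)
def test_blockwise_fp8_gemm():
    ...  # exercise blockwise_fp8_gemm here
```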

@danielvegamyhre
Contributor

@Degnel thanks for your work on this, I ran the tests and it looks like your blockwise fp8 gemm test is failing due to a quantization error.

@the-tuning-machine
Contributor Author

@Degnel thanks for your work on this, I ran the tests and it looks like your blockwise fp8 gemm test is failing due to a quantization error.

Thanks for pointing that out! I had also noticed the issue, and I think I was just a bit too harsh with the threshold. I'll increase it to make it more reasonable. That said, I'll still double-check the calculations manually to ensure everything is mathematically correct.
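
For context, the kind of precision check being discussed can be expressed as a signal-to-quantization-noise ratio (SQNR) threshold, matching the dB error metric reported in the benchmarks below; the 25 dB value here is purely illustrative:

```python
import torch


def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB between a reference and an approximation.
    noise = (ref - approx).pow(2).mean()
    signal = ref.pow(2).mean()
    return (10 * torch.log10(signal / noise)).item()


def check_blockwise_gemm(A, W, out_blockwise, min_sqnr_db: float = 25.0):
    # Compare against a high-precision reference matmul; W is (N, K) as in the PR snippets.
    ref = A.float() @ W.float().t()
    assert sqnr_db(ref, out_blockwise.float()) > min_sqnr_db
```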

@the-tuning-machine
Contributor Author

the-tuning-machine commented Mar 13, 2025

@danielvegamyhre I believe that everything should be alright except for the PR Label Check (I'm not sure if I have the required rights to edit this). The test-mps-ops (macos-m1-stable) failed, but I think that the merge will fix it as it seems to be a newly introduced test.

@danielvegamyhre danielvegamyhre added the topic: new feature Use this tag if this PR adds a new feature label Mar 13, 2025
@the-tuning-machine
Contributor Author

the-tuning-machine commented Mar 14, 2025

The test-mps-ops (macos-m1-stable) failed once again. I've seen other recent PRs both succeeding and failing this test (due to the same missing package 'importlib_metadata'). I don't think this is related to the code I wrote, but I might be missing something. Please, let me know if you have any insights.

@drisspg
Contributor

drisspg commented Mar 14, 2025

The mps test failure is unrelated, re-running tests.

@the-tuning-machine
Contributor Author

the-tuning-machine commented Apr 21, 2025

It seems like new PRs are no longer failing the macOS tests. Maybe we should try rerunning them here :) @danielvegamyhre @drisspg

@drisspg
Contributor

drisspg commented Apr 24, 2025

Sorry, could you do one more rebase to kick off CI again?

- fp8 triton gemm
- quant, dequant and linear utils
- time & precision benchmarks
- basic tests
- removing triton dependency
- cleaning adaptive dtype
- fixing W4A8 quantization for cutlass kernel in precision benchmark
- importing triton only if cuda available
- setting a less harsh threshold for quant-dequant and for gemm kernel mm precision
- condition triton import in gemm
- linting
@the-tuning-machine the-tuning-machine force-pushed the feat/blockwise_fp8_quant_triton_gemm_ker branch from e8edea9 to e41457c Compare April 25, 2025 12:11
@the-tuning-machine
Contributor Author

Sorry, could you do one more rebase to kick off CI again?

No problem, it should be ok

@the-tuning-machine
Contributor Author

Thank you @drisspg, I've fixed the linting.

@drisspg
Contributor

drisspg commented Apr 26, 2025

cc @danielvegamyhre, can you carry this across the finish line?

Contributor

@danielvegamyhre danielvegamyhre left a comment

Great progress @Degnel and sorry for the delay! Just did another pass and left a couple comments.

W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

qact = _int8_symm_per_token_reduced_range_quant_cutlass(A)
Contributor

Why are int4 and int8 quant being used here?

Note these no longer exist in torchao (_int4_symm_per_token_quant_cutlass and _int8_symm_per_token_reduced_range_quant_cutlass), so they should be removed/changed, but I'm unclear why int4/int8 quant is being used in the first place. Can you clarify?

Contributor Author

To be transparent, this code was written a few months ago, and I don't fully recall the rationale behind using those specific quantization modes. At the time, the goal was to establish a baseline for comparison across different quantization strategies, and int4/int8 quantization options seemed like a good starting point. Looking back, I agree it was probably a poor choice. I've removed the deprecated imports.

the-tuning-machine and others added 3 commits April 29, 2025 14:52
- raising explicit error when running benchmark without cuda
- merging quant, dequant and gemm code into one file
- removing deprecated int4/int8 comparison
…:Degnel/ao into feat/blockwise_fp8_quant_triton_gemm_ker
@danielvegamyhre
Contributor

@Degnel I get an import error when trying to run the benchmarks, can you take a look?

(torchtitan) [[email protected] ~/ao/benchmarks (feat/blockwise_fp8_quant_triton_gemm_ker)]$ python benchmark_blockwise_scaled_linear_triton.py 
Traceback (most recent call last):
  File "/home/danvm/ao/benchmarks/benchmark_blockwise_scaled_linear_triton.py", line 15, in <module>
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    ...<3 lines>...
    )
  File "/home/danvm/ao/torchao/prototype/blockwise_fp8/__init__.py", line 1, in <module>
    from .blockwise_fp8_gemm_triton import blockwise_fp8_gemm
ModuleNotFoundError: No module named 'torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton'

- fix import in __init__.py and in blockwise_linear.py
@danielvegamyhre
Contributor

danielvegamyhre commented Apr 30, 2025

Interesting, I see a solid speedup vs fp16 for small M, but it falls off as M increases, to the point where fp8 blockwise is actually 4x slower than fp16 for larger shapes, which is the opposite of what I would expect... this is on H100s for context.

Maybe the block sizes in the autotuner are not optimal for shapes of this size? What do you think @Degnel @drisspg? I suppose we can merge this as an initial start, but we should profile this and fix the perf for larger shapes.

|   m |     k |     n |   block_size | dtype               |   fp16_latency (ms) |   blockwise_latency (ms) |   blockwise_speedup |
|----:|------:|------:|-------------:|:--------------------|--------------------:|-------------------------:|--------------------:|
|   1 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.144 |                   52.256 |            1.57195  |
|   1 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              95.936 |                   61.184 |            1.56799  |
|   1 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             429.472 |                  233.712 |            1.83761  |
|   1 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             245.44  |                  131.296 |            1.86936  |
|   1 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.208 |                   52.704 |            1.55981  |
|   1 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.192 |                   61.056 |            1.57547  |
|   1 |  8192 | 57344 |          128 | torch.float8_e5m2   |             430.944 |                  234.08  |            1.84101  |
|   1 | 28672 |  8192 |          128 | torch.float8_e5m2   |             244.736 |                  131.104 |            1.86673  |
|   2 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.048 |                   53.92  |            1.52166  |
|   2 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.448 |                   62.016 |            1.55521  |
|   2 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             432.16  |                  233.504 |            1.85076  |
|   2 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.544 |                  135.808 |            1.68285  |
|   2 |  8192 |  8192 |          128 | torch.float8_e5m2   |              81.984 |                   54.048 |            1.51687  |
|   2 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.352 |                   61.952 |            1.55527  |
|   2 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.544 |                  235.744 |            1.8348   |
|   2 | 28672 |  8192 |          128 | torch.float8_e5m2   |             227.328 |                  136.512 |            1.66526  |
|   4 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.816 |                   54.336 |            1.52415  |
|   4 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.608 |                   62.048 |            1.55699  |
|   4 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             434.88  |                  234.624 |            1.85352  |
|   4 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             231.52  |                  137.12  |            1.68845  |
|   4 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.72  |                   54.208 |            1.52597  |
|   4 |  8192 | 10240 |          128 | torch.float8_e5m2   |              95.84  |                   62.368 |            1.53669  |
|   4 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.992 |                  238.464 |            1.81575  |
|   4 | 28672 |  8192 |          128 | torch.float8_e5m2   |             237.024 |                  136.032 |            1.74241  |
|   8 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.912 |                   56.512 |            1.46716  |
|   8 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.128 |                   62.64  |            1.53461  |
|   8 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             433.84  |                  243.68  |            1.78037  |
|   8 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.736 |                  143.072 |            1.59875  |
|   8 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.52  |                   56.736 |            1.47208  |
|   8 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.64  |                   63.072 |            1.53222  |
|   8 |  8192 | 57344 |          128 | torch.float8_e5m2   |             433.92  |                  243.296 |            1.78351  |
|   8 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.952 |                  140.128 |            1.66956  |
|  16 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.072 |                   58.816 |            1.4124   |
|  16 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              95.392 |                   66.944 |            1.42495  |
|  16 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             436.832 |                  244.352 |            1.78772  |
|  16 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             228.864 |                  152.64  |            1.49937  |
|  16 |  8192 |  8192 |          128 | torch.float8_e5m2   |              82.112 |                   60.416 |            1.35911  |
|  16 |  8192 | 10240 |          128 | torch.float8_e5m2   |              97.088 |                   66.464 |            1.46076  |
|  16 |  8192 | 57344 |          128 | torch.float8_e5m2   |             437.344 |                  247.728 |            1.76542  |
|  16 | 28672 |  8192 |          128 | torch.float8_e5m2   |             236.064 |                  155.072 |            1.52229  |
|  32 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.96  |                   62.72  |            1.35459  |
|  32 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.488 |                   69.248 |            1.43669  |
|  32 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             439.552 |                  281.328 |            1.56242  |
|  32 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.76  |                  168.064 |            1.3909   |
|  32 |  8192 |  8192 |          128 | torch.float8_e5m2   |              85.024 |                   62.672 |            1.35665  |
|  32 |  8192 | 10240 |          128 | torch.float8_e5m2   |              99.488 |                   69.088 |            1.44002  |
|  32 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.168 |                  275.616 |            1.59341  |
|  32 | 28672 |  8192 |          128 | torch.float8_e5m2   |             228.512 |                  167.888 |            1.3611   |
|  64 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              82.016 |                   76.768 |            1.06836  |
|  64 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.848 |                   92.032 |            1.05233  |
|  64 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             456.48  |                  515.248 |            0.885942 |
|  64 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             239.84  |                  244.864 |            0.979482 |
|  64 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.2   |                   77.28  |            1.0766   |
|  64 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.384 |                   91.968 |            1.04802  |
|  64 |  8192 | 57344 |          128 | torch.float8_e5m2   |             457.408 |                  506.368 |            0.903311 |
|  64 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.76  |                  243.904 |            0.95841  |
| 128 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.064 |                  141.696 |            0.59327  |
| 128 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.992 |                  170.688 |            0.591676 |
| 128 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             454.592 |                  991.936 |            0.458288 |
| 128 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.72  |                  483.04  |            0.485922 |
| 128 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.264 |                  141.616 |            0.587956 |
| 128 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.768 |                  171.488 |            0.58761  |
| 128 |  8192 | 57344 |          128 | torch.float8_e5m2   |             454.72  |                  982.976 |            0.462595 |
| 128 | 28672 |  8192 |          128 | torch.float8_e5m2   |             236.512 |                  479.872 |            0.492865 |
| 256 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              89.44  |                  285.504 |            0.313271 |
| 256 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             113.152 |                  342.144 |            0.330715 |
| 256 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             520.576 |                 1995.15  |            0.26092  |
| 256 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             244.48  |                  961.776 |            0.254196 |
| 256 |  8192 |  8192 |          128 | torch.float8_e5m2   |              89.056 |                  270.272 |            0.329505 |
| 256 |  8192 | 10240 |          128 | torch.float8_e5m2   |             111.424 |                  337.824 |            0.329829 |
| 256 |  8192 | 57344 |          128 | torch.float8_e5m2   |             527.168 |                 2036.4   |            0.258872 |
| 256 | 28672 |  8192 |          128 | torch.float8_e5m2   |             255.552 |                  961.888 |            0.265677 |
| 512 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             122.272 |                  556.512 |            0.219711 |
| 512 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             197.728 |                  674.688 |            0.293066 |
| 512 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             890.128 |                 4130.8   |            0.215486 |
| 512 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             412.416 |                 1883.71  |            0.218938 |
| 512 |  8192 |  8192 |          128 | torch.float8_e5m2   |             123.552 |                  542.624 |            0.227694 |
| 512 |  8192 | 10240 |          128 | torch.float8_e5m2   |             198.032 |                  678.016 |            0.292076 |
| 512 |  8192 | 57344 |          128 | torch.float8_e5m2   |             879.696 |                 3877.6   |            0.226866 |
| 512 | 28672 |  8192 |          128 | torch.float8_e5m2   |             412.256 |                 1936.1   |            0.212932 |

@the-tuning-machine
Contributor Author

the-tuning-machine commented May 1, 2025

Interesting, I see a solid speedup vs fp16 for small M, but it falls off as M increases, to the point where fp8 blockwise is actually 4x slower than fp16 for larger shapes, which is the opposite of what I would expect... this is on H100s for context.

Maybe the block sizes in the autotuner are not optimal for shapes of this size? What do you think @Degnel @drisspg? I suppose we can merge this as an initial start, but we should profile this and fix the perf for larger shapes.

Thanks @danielvegamyhre. My first intuition was that it might be due to a hidden tensor copy, but the code looks fairly straightforward and I don’t immediately see anything problematic there. That said, the original paper does mention some known inefficiencies: “partial results will be copied from Tensor Cores to CUDA cores”. However, the sizes they benchmarked seem to be larger than the point at which we start seeing the drop in efficiency. I’ll try to dig into this over the next few days and see if I can pinpoint what’s going on.

the-tuning-machine and others added 3 commits May 7, 2025 12:08
- the autotuner was optimizing based only on small M sizes at the beginning of the benchmark
- added a `M_bucket` key to the autotuner to enable tuning based on similar M sizes
- added `128` to the `BLOCK_SIZE_M` configuration, which improves performance for large M values
- launcher now takes `block_size` into account (although using `block_size=128` is recommended for best performance)
…:Degnel/ao into feat/blockwise_fp8_quant_triton_gemm_ker
@the-tuning-machine
Contributor Author

the-tuning-machine commented May 7, 2025

@danielvegamyhre @drisspg I finally got the time to dive into the problem.
The significant performance issue for large M values was due to the autotuner not taking M into account when tuning the parameters. It was optimizing for the small values appearing early in the benchmark. To address this, an M_BUCKET key was added to the autotuner to allow tuning based on similar ranges of M. In theory, we could have tuned for every distinct M, but since M corresponds to the context length in an LLM and can vary at each call, I guess it’s better to adopt a more parsimonious approach.

On an L4, the default DeepSeek parameters are already quite solid. I eventually found out that adding 128 to the BLOCK_SIZE_M configuration leads to some performance gains for large M (before that change, computing the matmul in fp16 was actually faster than the optimized path for M = 512, even with the key fix).
Let me know if any other changes are needed.
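
For reference, a sketch of what the autotuner change described above could look like (illustrative only, not the exact kernel in the PR; pointer names and config values other than `M_BUCKET` and the `BLOCK_SIZE_M` candidates are assumptions):

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128})
        for bm in (16, 32, 64, 128)  # 128 added to help large-M shapes
    ],
    key=["M_BUCKET", "N", "K"],  # tune per bucket of M instead of per exact M
)
@triton.jit
def blockwise_fp8_gemm_kernel(
    a_ptr, a_s_ptr, b_ptr, b_s_ptr, c_ptr,
    M, N, K, M_BUCKET,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # fp8 blockwise matmul body elided in this sketch
    pass


def launch(a, a_s, b, b_s, c, M, N, K):
    # Coarse bucket so that similar M values share one tuned config.
    m_bucket = triton.next_power_of_2(M)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )
    blockwise_fp8_gemm_kernel[grid](a, a_s, b, b_s, c, M, N, K, m_bucket)
```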

@danielvegamyhre
Contributor

danielvegamyhre commented May 8, 2025

@Degnel Thanks for fixing the autotune config for M and verifying the perf. I want to run everything locally to double-check, but I'm getting an import error again when trying to run the benchmark. This happened previously and you fixed it; did the fix get undone somehow?

Traceback (most recent call last):
  File "<python-input-1>", line 1, in <module>
    from torchao.prototype.blockwise_fp8 import blockwise_fp8_gemm
ModuleNotFoundError: No module named 'torchao.prototype.blockwise_fp8'

@the-tuning-machine
Contributor Author

@danielvegamyhre I just checked, and I don’t see 'from torchao.prototype.blockwise_fp8 import blockwise_fp8_gemm' anywhere in the latest version of the code (https://github.com/pytorch/ao/pull/1763/files). Maybe you're using an older version? Let me know if that’s not the case.

If needed, the correct import should be: 'from torchao.prototype.blockwise_fp8.blockwise_quantization import blockwise_fp8_gemm'.
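
A quick way to sanity-check that import path (assuming torchao is installed from this branch into the active environment):

```python
# Smoke test for the import path quoted above.
from torchao.prototype.blockwise_fp8.blockwise_quantization import blockwise_fp8_gemm

print(blockwise_fp8_gemm)
```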

@danielvegamyhre
Contributor

@danielvegamyhre I just checked, and I don’t see 'from torchao.prototype.blockwise_fp8 import blockwise_fp8_gemm' anywhere in the latest version of the code (https://github.com/pytorch/ao/pull/1763/files). Maybe you're using an older version? Let me know if that’s not the case.

If needed, the correct import should be: 'from torchao.prototype.blockwise_fp8.blockwise_quantization import blockwise_fp8_gemm'.

Re-checked out the PR and am getting a different import error when running the benchmark now:

(torchtitan) [[email protected] ~/ao/benchmarks (feat/blockwise_fp8_quant_triton_gemm_ker)]$ python benchmark_blockwise_scaled_linear_triton.py 
Traceback (most recent call last):
  File "/home/danvm/ao/benchmarks/benchmark_blockwise_scaled_linear_triton.py", line 16, in <module>
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    ...<3 lines>...
    )
ModuleNotFoundError: No module named 'torchao.prototype.blockwise_fp8'

This particular line does appear in the latest PR.

Thanks for sticking with this, looks good to me once we get this fixed!

@the-tuning-machine
Contributor Author

(torchtitan) [[email protected] ~/ao/benchmarks (feat/blockwise_fp8_quant_triton_gemm_ker)]$ python benchmark_blockwise_scaled_linear_triton.py 
Traceback (most recent call last):
  File "/home/danvm/ao/benchmarks/benchmark_blockwise_scaled_linear_triton.py", line 16, in <module>
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    ...<3 lines>...
    )
ModuleNotFoundError: No module named 'torchao.prototype.blockwise_fp8'

Hi @danielvegamyhre, I think that you should run 'pip install -e .' at the root of your project. It looks like Python does not recognize the newly added module.
Let me know how it goes, and I can take another look if needed.

@danielvegamyhre
Contributor

Latest benchmarks:

|   m |     k |     n |   block_size | dtype               |   fp16_latency (ms) |   blockwise_latency (ms) |   blockwise_speedup |
|----:|------:|------:|-------------:|:--------------------|--------------------:|-------------------------:|--------------------:|
|   1 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.096 |                   52.352 |            1.60636  |
|   1 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.616 |                   61.344 |            1.62389  |
|   1 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             436.768 |                  234.016 |            1.8664   |
|   1 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.68  |                  131.456 |            1.77763  |
|   1 |  8192 |  8192 |          128 | torch.float8_e5m2   |              84.992 |                   52.832 |            1.60872  |
|   1 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.512 |                   61.024 |            1.64709  |
|   1 |  8192 | 57344 |          128 | torch.float8_e5m2   |             441.6   |                  234.112 |            1.88628  |
|   1 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.44  |                  131.2   |            1.77927  |
|   2 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.584 |                   53.728 |            1.55569  |
|   2 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.384 |                   61.76  |            1.62539  |
|   2 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             432.672 |                  233.664 |            1.85168  |
|   2 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.92  |                  133.664 |            1.75006  |
|   2 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.424 |                   53.664 |            1.55456  |
|   2 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.688 |                   61.792 |            1.62947  |
|   2 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.416 |                  235.088 |            1.83938  |
|   2 | 28672 |  8192 |          128 | torch.float8_e5m2   |             234.016 |                  136.736 |            1.71144  |
|   4 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.32  |                   53.088 |            1.58831  |
|   4 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.704 |                   61.92  |            1.62636  |
|   4 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             433.024 |                  235.104 |            1.84184  |
|   4 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.464 |                  134.592 |            1.74204  |
|   4 |  8192 |  8192 |          128 | torch.float8_e5m2   |              84.032 |                   53.568 |            1.5687   |
|   4 |  8192 | 10240 |          128 | torch.float8_e5m2   |              99.968 |                   62.064 |            1.61072  |
|   4 |  8192 | 57344 |          128 | torch.float8_e5m2   |             433.376 |                  238.432 |            1.81761  |
|   4 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.744 |                  134.272 |            1.75572  |
|   8 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.776 |                   53.632 |            1.56205  |
|   8 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.864 |                   63.072 |            1.59919  |
|   8 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             439.296 |                  239.136 |            1.83701  |
|   8 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.256 |                  135.2   |            1.73266  |
|   8 |  8192 |  8192 |          128 | torch.float8_e5m2   |              84.288 |                   53.888 |            1.56413  |
|   8 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.672 |                   63.104 |            1.59533  |
|   8 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.6   |                  238.976 |            1.83952  |
|   8 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.2   |                  135.552 |            1.73513  |
|  16 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.904 |                   53.792 |            1.55979  |
|  16 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.552 |                   63.264 |            1.5736   |
|  16 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             444.224 |                  245.408 |            1.81014  |
|  16 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             235.616 |                  133.888 |            1.7598   |
|  16 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.648 |                   53.696 |            1.55781  |
|  16 |  8192 | 10240 |          128 | torch.float8_e5m2   |             101.344 |                   63.296 |            1.60111  |
|  16 |  8192 | 57344 |          128 | torch.float8_e5m2   |             444.096 |                  244.496 |            1.81637  |
|  16 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.424 |                  133.216 |            1.76724  |
|  32 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.968 |                   53.344 |            1.57409  |
|  32 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             102.88  |                   63.472 |            1.62087  |
|  32 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             441.6   |                  242.608 |            1.82022  |
|  32 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             237.056 |                  133.44  |            1.7765   |
|  32 |  8192 |  8192 |          128 | torch.float8_e5m2   |              86.24  |                   53.408 |            1.61474  |
|  32 |  8192 | 10240 |          128 | torch.float8_e5m2   |             102.304 |                   63.52  |            1.61058  |
|  32 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.936 |                  244.672 |            1.79806  |
|  32 | 28672 |  8192 |          128 | torch.float8_e5m2   |             238.304 |                  134.048 |            1.77775  |
|  64 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              86.144 |                   53.824 |            1.60048  |
|  64 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              93.728 |                   64.064 |            1.46304  |
|  64 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             471.6   |                  244.56  |            1.92836  |
|  64 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             240.16  |                  136.48  |            1.75967  |
|  64 |  8192 |  8192 |          128 | torch.float8_e5m2   |              86.144 |                   54.048 |            1.59384  |
|  64 |  8192 | 10240 |          128 | torch.float8_e5m2   |              93.472 |                   64.16  |            1.45686  |
|  64 |  8192 | 57344 |          128 | torch.float8_e5m2   |             470.592 |                  244.192 |            1.92714  |
|  64 | 28672 |  8192 |          128 | torch.float8_e5m2   |             241.568 |                  136.544 |            1.76916  |
| 128 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              91.168 |                   57.28  |            1.59162  |
| 128 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.896 |                   68     |            1.42494  |
| 128 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             449.472 |                  281.344 |            1.59759  |
| 128 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             246.496 |                  146.08  |            1.6874   |
| 128 |  8192 |  8192 |          128 | torch.float8_e5m2   |              89.312 |                   57.568 |            1.55142  |
| 128 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.672 |                   68.96  |            1.40186  |
| 128 |  8192 | 57344 |          128 | torch.float8_e5m2   |             449.92  |                  282.016 |            1.59537  |
| 128 | 28672 |  8192 |          128 | torch.float8_e5m2   |             246.736 |                  147.904 |            1.66822  |
| 256 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              85.92  |                   62.432 |            1.37622  |
| 256 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             101.44  |                  104.736 |            0.96853  |
| 256 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             478.56  |                  433.856 |            1.10304  |
| 256 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             258.24  |                  207.488 |            1.2446   |
| 256 |  8192 |  8192 |          128 | torch.float8_e5m2   |              86.24  |                   62.144 |            1.38774  |
| 256 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.928 |                  103.936 |            0.971059 |
| 256 |  8192 | 57344 |          128 | torch.float8_e5m2   |             474.56  |                  429.024 |            1.10614  |
| 256 | 28672 |  8192 |          128 | torch.float8_e5m2   |             259.872 |                  206.688 |            1.25732  |
| 512 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             117.056 |                  111.36  |            1.05115  |
| 512 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             150.592 |                  161.376 |            0.933175 |
| 512 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             826.608 |                  828.672 |            0.997509 |
| 512 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             405.44  |                  392.512 |            1.03294  |
| 512 |  8192 |  8192 |          128 | torch.float8_e5m2   |             119.072 |                  110.944 |            1.07326  |
| 512 |  8192 | 10240 |          128 | torch.float8_e5m2   |             152.144 |                  160.864 |            0.945793 |
| 512 |  8192 | 57344 |          128 | torch.float8_e5m2   |             825.856 |                  841.728 |            0.981144 |
| 512 | 28672 |  8192 |          128 | torch.float8_e5m2   |             398.656 |                  390.336 |            1.02131  |

|   m |     k |     n |   block_size | dtype               |   error_blockwise (dB) |
|----:|------:|------:|-------------:|:--------------------|-----------------------:|
|   1 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8681 |
|   1 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8298 |
|   1 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.774  |
|   1 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8339 |
|   1 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.9387 |
|   1 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.7398 |
|   1 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8115 |
|   1 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8887 |
|   2 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8904 |
|   2 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.746  |
|   2 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8572 |
|   2 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.82   |
|   2 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8174 |
|   2 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8006 |
|   2 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.837  |
|   2 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.7051 |
|   4 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8256 |
|   4 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8403 |
|   4 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.7766 |
|   4 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8333 |
|   4 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.7843 |
|   4 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8276 |
|   4 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8254 |
|   4 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8338 |
|   8 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.7693 |
|   8 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8375 |
|   8 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8292 |
|   8 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.7461 |
|   8 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8545 |
|   8 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8707 |
|   8 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8205 |
|   8 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8111 |
|  16 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8097 |
|  16 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8552 |
|  16 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8185 |
|  16 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.805  |
|  16 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.7863 |
|  16 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8112 |
|  16 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8025 |
|  16 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8063 |
|  32 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8041 |
|  32 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8168 |
|  32 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8386 |
|  32 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8064 |
|  32 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8266 |
|  32 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8116 |
|  32 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8181 |
|  32 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8257 |
|  64 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8238 |
|  64 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8131 |
|  64 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8228 |
|  64 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8164 |
|  64 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8303 |
|  64 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.7985 |
|  64 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8227 |
|  64 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8289 |
| 128 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8135 |
| 128 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.7909 |
| 128 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8222 |
| 128 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8236 |
| 128 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8158 |
| 128 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8262 |
| 128 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8082 |
| 128 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8173 |
| 256 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.8256 |
| 256 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8151 |
| 256 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8237 |
| 256 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8109 |
| 256 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8123 |
| 256 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8146 |
| 256 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8149 |
| 256 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8057 |
| 512 |  8192 |  8192 |          128 | torch.float8_e4m3fn |                28.813  |
| 512 |  8192 | 10240 |          128 | torch.float8_e4m3fn |                28.8236 |
| 512 |  8192 | 57344 |          128 | torch.float8_e4m3fn |                28.8235 |
| 512 | 28672 |  8192 |          128 | torch.float8_e4m3fn |                28.8219 |
| 512 |  8192 |  8192 |          128 | torch.float8_e5m2   |                22.8135 |
| 512 |  8192 | 10240 |          128 | torch.float8_e5m2   |                22.8197 |
| 512 |  8192 | 57344 |          128 | torch.float8_e5m2   |                22.8099 |
| 512 | 28672 |  8192 |          128 | torch.float8_e5m2   |                22.8058 |

@danielvegamyhre
Contributor

Benchmarks on H100: it looks like for high values of M we get a slight slowdown of a few percent at most, but nothing like the 0.38x observed on A100 by @Degnel. We can include these in the prototype README as a next step. Thanks for adding this @Degnel!

|    m |     k |     n |   block_size | dtype               |   fp16_latency (ms) |   blockwise_latency (ms) |   blockwise_speedup |
|-----:|------:|------:|-------------:|:--------------------|--------------------:|-------------------------:|--------------------:|
|    1 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.744 |                   52.224 |            1.60355  |
|    1 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.52  |                   61.12  |            1.62827  |
|    1 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             436.608 |                  234     |            1.86585  |
|    1 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.568 |                  131.168 |            1.78068  |
|    1 |  8192 |  8192 |          128 | torch.float8_e5m2   |              84.896 |                   52.736 |            1.60983  |
|    1 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.224 |                   60.96  |            1.64409  |
|    1 |  8192 | 57344 |          128 | torch.float8_e5m2   |             441.152 |                  233.968 |            1.88552  |
|    1 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.28  |                  130.816 |            1.78327  |
|    2 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.392 |                   53.664 |            1.55397  |
|    2 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.192 |                   61.632 |            1.62565  |
|    2 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             432.384 |                  233.664 |            1.85045  |
|    2 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             233.648 |                  133.6   |            1.74886  |
|    2 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.232 |                   53.6   |            1.55284  |
|    2 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.608 |                   61.664 |            1.63155  |
|    2 |  8192 | 57344 |          128 | torch.float8_e5m2   |             432.32  |                  235.152 |            1.83847  |
|    2 | 28672 |  8192 |          128 | torch.float8_e5m2   |             233.824 |                  136.256 |            1.71606  |
|    4 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              84.16  |                   52.928 |            1.59008  |
|    4 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.544 |                   61.728 |            1.62882  |
|    4 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             432.768 |                  234.944 |            1.842    |
|    4 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.432 |                  134.432 |            1.74387  |
|    4 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.872 |                   53.408 |            1.5704   |
|    4 |  8192 | 10240 |          128 | torch.float8_e5m2   |              99.84  |                   62.24  |            1.60411  |
|    4 |  8192 | 57344 |          128 | torch.float8_e5m2   |             433.376 |                  238.272 |            1.81883  |
|    4 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.584 |                  134.08  |            1.75704  |
|    8 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.648 |                   53.472 |            1.56433  |
|    8 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             100.704 |                   62.432 |            1.61302  |
|    8 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             439.104 |                  238.208 |            1.84336  |
|    8 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             234.272 |                  135.072 |            1.73442  |
|    8 |  8192 |  8192 |          128 | torch.float8_e5m2   |              84.128 |                   53.728 |            1.56581  |
|    8 |  8192 | 10240 |          128 | torch.float8_e5m2   |             100.512 |                   62.976 |            1.59604  |
|    8 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.36  |                  238.496 |            1.84221  |
|    8 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.04  |                  135.424 |            1.73559  |
|   16 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.808 |                   53.664 |            1.56172  |
|   16 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              99.584 |                   63.104 |            1.57809  |
|   16 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             444     |                  244.192 |            1.81824  |
|   16 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             235.52  |                  133.792 |            1.76034  |
|   16 |  8192 |  8192 |          128 | torch.float8_e5m2   |              83.488 |                   53.568 |            1.55854  |
|   16 |  8192 | 10240 |          128 | torch.float8_e5m2   |             101.216 |                   63.232 |            1.60071  |
|   16 |  8192 | 57344 |          128 | torch.float8_e5m2   |             444.608 |                  245.936 |            1.80782  |
|   16 | 28672 |  8192 |          128 | torch.float8_e5m2   |             235.36  |                  133.152 |            1.7676   |
|   32 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              83.872 |                   53.312 |            1.57323  |
|   32 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             102.688 |                   63.264 |            1.62317  |
|   32 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             441.792 |                  243.04  |            1.81777  |
|   32 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             237.12  |                  133.632 |            1.77443  |
|   32 |  8192 |  8192 |          128 | torch.float8_e5m2   |              86.08  |                   53.216 |            1.61756  |
|   32 |  8192 | 10240 |          128 | torch.float8_e5m2   |             102.032 |                   63.2   |            1.61443  |
|   32 |  8192 | 57344 |          128 | torch.float8_e5m2   |             439.168 |                  245.184 |            1.79118  |
|   32 | 28672 |  8192 |          128 | torch.float8_e5m2   |             238.016 |                  134.336 |            1.7718   |
|   64 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              85.888 |                   53.632 |            1.60143  |
|   64 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              93.632 |                   63.936 |            1.46446  |
|   64 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             471.44  |                  245.2   |            1.92268  |
|   64 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             240     |                  137.424 |            1.74642  |
|   64 |  8192 |  8192 |          128 | torch.float8_e5m2   |              85.984 |                   54.016 |            1.59182  |
|   64 |  8192 | 10240 |          128 | torch.float8_e5m2   |              93.376 |                   64.032 |            1.45827  |
|   64 |  8192 | 57344 |          128 | torch.float8_e5m2   |             471.36  |                  244.576 |            1.92725  |
|   64 | 28672 |  8192 |          128 | torch.float8_e5m2   |             242.4   |                  136.096 |            1.7811   |
|  128 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              91.008 |                   57.184 |            1.59149  |
|  128 |  8192 | 10240 |          128 | torch.float8_e4m3fn |              96.608 |                   67.936 |            1.42204  |
|  128 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             449.6   |                  292.48  |            1.5372   |
|  128 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             247.84  |                  147.232 |            1.68333  |
|  128 |  8192 |  8192 |          128 | torch.float8_e5m2   |              89.152 |                   57.248 |            1.55729  |
|  128 |  8192 | 10240 |          128 | torch.float8_e5m2   |              96.64  |                   68.784 |            1.40498  |
|  128 |  8192 | 57344 |          128 | torch.float8_e5m2   |             450.048 |                  284.16  |            1.58378  |
|  128 | 28672 |  8192 |          128 | torch.float8_e5m2   |             246.88  |                  148.064 |            1.66739  |
|  256 |  8192 |  8192 |          128 | torch.float8_e4m3fn |              85.984 |                   62.368 |            1.37866  |
|  256 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             101.216 |                  104.896 |            0.964918 |
|  256 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             477.984 |                  452.832 |            1.05554  |
|  256 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             260.224 |                  215.392 |            1.20814  |
|  256 |  8192 |  8192 |          128 | torch.float8_e5m2   |              86.432 |                   62.048 |            1.39299  |
|  256 |  8192 | 10240 |          128 | torch.float8_e5m2   |             101.024 |                  103.904 |            0.972282 |
|  256 |  8192 | 57344 |          128 | torch.float8_e5m2   |             475.568 |                  433.792 |            1.0963   |
|  256 | 28672 |  8192 |          128 | torch.float8_e5m2   |             261.824 |                  207.968 |            1.25896  |
|  512 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             117.952 |                  112.992 |            1.0439   |
|  512 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             151.504 |                  166.08  |            0.912235 |
|  512 |  8192 | 57344 |          128 | torch.float8_e4m3fn |             836.848 |                  881.312 |            0.949548 |
|  512 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             442.528 |                  402.464 |            1.09955  |
|  512 |  8192 |  8192 |          128 | torch.float8_e5m2   |             121.184 |                  114.592 |            1.05753  |
|  512 |  8192 | 10240 |          128 | torch.float8_e5m2   |             151.424 |                  163.296 |            0.927298 |
|  512 |  8192 | 57344 |          128 | torch.float8_e5m2   |             837.312 |                  873.664 |            0.958391 |
|  512 | 28672 |  8192 |          128 | torch.float8_e5m2   |             437.664 |                  400.928 |            1.09163  |
| 1024 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             227.008 |                  224.384 |            1.01169  |
| 1024 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             289.28  |                  283.872 |            1.01905  |
| 1024 |  8192 | 57344 |          128 | torch.float8_e4m3fn |            1672.13  |                 1673.34  |            0.999273 |
| 1024 | 28672 |  8192 |          128 | torch.float8_e4m3fn |             800     |                  769.152 |            1.04011  |
| 1024 |  8192 |  8192 |          128 | torch.float8_e5m2   |             224.48  |                  223.456 |            1.00458  |
| 1024 |  8192 | 10240 |          128 | torch.float8_e5m2   |             289.408 |                  283.424 |            1.02111  |
| 1024 |  8192 | 57344 |          128 | torch.float8_e5m2   |            1649.58  |                 1626.88  |            1.01396  |
| 1024 | 28672 |  8192 |          128 | torch.float8_e5m2   |             805.392 |                  768.416 |            1.04812  |
| 2048 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             449.344 |                  458.272 |            0.980518 |
| 2048 |  8192 | 10240 |          128 | torch.float8_e4m3fn |             569.888 |                  586.224 |            0.972134 |
| 2048 |  8192 | 57344 |          128 | torch.float8_e4m3fn |            3275.84  |                 3251.9   |            1.00736  |
| 2048 | 28672 |  8192 |          128 | torch.float8_e4m3fn |            1614.37  |                 1555.68  |            1.03772  |
| 2048 |  8192 |  8192 |          128 | torch.float8_e5m2   |             450.624 |                  461.712 |            0.975985 |
| 2048 |  8192 | 10240 |          128 | torch.float8_e5m2   |             575.36  |                  582.016 |            0.988564 |
| 2048 |  8192 | 57344 |          128 | torch.float8_e5m2   |            3363.3   |                 3213.31  |            1.04668  |
| 2048 | 28672 |  8192 |          128 | torch.float8_e5m2   |            1574.32  |                 1525.66  |            1.03189  |
| 4096 |  8192 |  8192 |          128 | torch.float8_e4m3fn |             915.216 |                  964.592 |            0.948812 |
| 4096 |  8192 | 10240 |          128 | torch.float8_e4m3fn |            1157.18  |                 1196.42  |            0.967209 |
| 4096 |  8192 | 57344 |          128 | torch.float8_e4m3fn |            6409.98  |                 6638.3   |            0.965606 |
| 4096 | 28672 |  8192 |          128 | torch.float8_e4m3fn |            3173.76  |                 3247.23  |            0.977374 |
| 4096 |  8192 |  8192 |          128 | torch.float8_e5m2   |             898.432 |                  949.36  |            0.946355 |
| 4096 |  8192 | 10240 |          128 | torch.float8_e5m2   |            1170.62  |                 1188.45  |            0.985002 |
| 4096 |  8192 | 57344 |          128 | torch.float8_e5m2   |            6751.25  |                 6573.71  |            1.02701  |
| 4096 | 28672 |  8192 |          128 | torch.float8_e5m2   |            3155.9   |                 3179.38  |            0.992617 |
| 8192 |  8192 |  8192 |          128 | torch.float8_e4m3fn |            1868.64  |                 2022.27  |            0.92403  |
| 8192 |  8192 | 10240 |          128 | torch.float8_e4m3fn |            2336.26  |                 2621.18  |            0.891298 |
| 8192 |  8192 | 57344 |          128 | torch.float8_e4m3fn |           13004     |                13990.6   |            0.929482 |
| 8192 | 28672 |  8192 |          128 | torch.float8_e4m3fn |            6781.49  |                 6722.82  |            1.00873  |
| 8192 |  8192 |  8192 |          128 | torch.float8_e5m2   |            1865.25  |                 1983.23  |            0.940509 |
| 8192 |  8192 | 10240 |          128 | torch.float8_e5m2   |            2296.66  |                 2523.1   |            0.91025  |
| 8192 |  8192 | 57344 |          128 | torch.float8_e5m2   |           13170.9   |                14029.6   |            0.938792 |
| 8192 | 28672 |  8192 |          128 | torch.float8_e5m2   |            6688.51  |                 6699.65  |            0.998338 |

@danielvegamyhre danielvegamyhre merged commit 69fc240 into pytorch:main May 12, 2025
17 of 18 checks passed
@the-tuning-machine
Contributor Author

Thanks a lot @danielvegamyhre for your assistance and for running the benchmarks on the H100! Would you like me to go ahead and add the blockwise quantization to the TorchAO API now, or do you think it's too early given the current performance results? Another next step could be to try benchmarking it on a QAT task (this could be an occasion to update the README). Let me know what you think would be relevant.

@danielvegamyhre
Contributor

danielvegamyhre commented Jun 24, 2025

Thanks a lot @danielvegamyhre for your assistance and for running the benchmarks on the H100! Would you like me to go ahead and add the blockwise quantization to the TorchAO API now, or do you think it's too early given the current performance results? Another next step could be to try benchmarking it on a QAT task (this could be an occasion to update the README). Let me know what you think would be relevant.

Sorry for the late reply, I just saw this. I think a good next step would be to make the float8 blockwise linear differentiable via an autograd function so it can be used for training, and to add some training benchmarks. Here is an example with triton kernels that uses autograd functions to make it differentiable for training. Do you want to work on this? I would be excited to land this in torchao.
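
A minimal sketch of that direction, assuming the gemm signature and `fp8_blockwise_weight_quant` from this PR (the activation-quant helper name is an assumption) and, for simplicity, a high-precision backward:

```python
import torch

# NOTE: fp8_blockwise_act_quant is an assumed helper name; fp8_blockwise_weight_quant
# and blockwise_fp8_gemm are the utilities from this PR.
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)


class BlockwiseFp8LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, W, block_size=128, dtype=torch.float8_e4m3fn):
        # A: (M, K) activations, W: (N, K) weights, output: (M, N)
        A_q, A_s = fp8_blockwise_act_quant(A, block_size, dtype)
        W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
        ctx.save_for_backward(A, W)
        return blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

    @staticmethod
    def backward(ctx, grad_out):
        A, W = ctx.saved_tensors
        # Simplified: gradients in high precision; a full version could quantize
        # grad_out blockwise and reuse the fp8 gemm for the backward matmuls.
        grad_A = grad_out @ W      # (M, N) @ (N, K) -> (M, K)
        grad_W = grad_out.t() @ A  # (N, M) @ (M, K) -> (N, K)
        return grad_A, grad_W, None, None
```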

@the-tuning-machine the-tuning-machine deleted the feat/blockwise_fp8_quant_triton_gemm_ker branch June 28, 2025 06:56