Commit 69fc240
Feat: Implementation of DeepSeek's blockwise quantization for fp8 tensors (#1763)
* Feat: Integration of DeepSeek's blockwise quantization
  - fp8 Triton GEMM
  - quant, dequant and linear utils
  - time & precision benchmarks
  - basic tests
* Doc: init + linting + readme
* Feat: adding triton dependency, adaptive testing dtype
* Fix:
  - removing the triton dependency
  - cleaning the adaptive dtype
* Fix:
  - fixing W4A8 quantization for the cutlass kernel in the precision benchmark
  - importing triton only if CUDA is available
  - setting a less harsh threshold for quant-dequant and for GEMM kernel mm precision
* Fix:
  - conditioning the triton import in the GEMM
  - linting
* Fix: triton pytest skip
* Linting
* Fix:
  - raising an explicit error when running the benchmark without CUDA
  - merging the quant, dequant and GEMM code into one file
  - removing the deprecated int4/int8 comparison
* Fix: imports in __init__.py and in blockwise_linear.py
* Optim: fixing poor performance for large M values (a minimal sketch of the idea follows the commit metadata below)
  > the autotuner was optimizing based only on the small M sizes seen at the beginning of the benchmark
  > added an `M_bucket` key to the autotuner to enable tuning based on similar M sizes
  > added `128` to the `BLOCK_SIZE_M` configurations, which improves performance for large M values
  > the launcher now takes `block_size` into account (although `block_size=128` is recommended for best performance)
* Fix: skipping the blockwise quant precision test for older versions of triton
* Bench: increasing the bench range to m=8192
1 parent 0607aa1 commit 69fc240
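The `M_bucket` change described in the commit message can be pictured with a minimal, hypothetical Triton sketch (illustrative kernel and names, not the committed GEMM): the launcher passes a bucketed M value as an extra kernel argument and lists it in the autotune key, so shapes with similar M reuse one cached config instead of inheriting whatever was tuned for the first small M, and `BLOCK_SIZE_M=128` is among the offered configs so large-M shapes can pick a bigger tile.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # Offer BLOCK_SIZE_M up to 128 so large-M shapes can use a bigger tile.
        triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": 64}, num_warps=4)
        for bm in (16, 32, 64, 128)
    ],
    # Including M_bucket in the key caches a separate best config per M range.
    key=["M_bucket", "N"],
)
@triton.jit
def _copy_2d_kernel(
    x_ptr, out_ptr, M, N, M_bucket,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    ptrs = offs_m[:, None] * N + offs_n[None, :]
    tl.store(out_ptr + ptrs, tl.load(x_ptr + ptrs, mask=mask), mask=mask)


def copy_2d(x: torch.Tensor) -> torch.Tensor:
    # Assumes a contiguous 2-D CUDA tensor.
    M, N = x.shape
    out = torch.empty_like(x)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]),
        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )
    # Bucket M (here: next power of two) so nearby M values share one tuned config.
    _copy_2d_kernel[grid](x, out, M, N, triton.next_power_of_2(M))
    return out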

File tree

6 files changed: +602 -0 lines changed

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import torch

if torch.cuda.is_available():
    import pandas as pd
    from tqdm import tqdm
    from triton.testing import do_bench

    from torchao.float8.float8_utils import compute_error
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
        blockwise_fp8_gemm,
        fp8_blockwise_act_quant,
        fp8_blockwise_weight_quant,
    )
    from torchao.utils import is_sm_at_least_89
else:
    raise RuntimeError("This benchmark is only available on CUDA hardware")


def benchmark_microseconds(f, *args, warmup=25, rep=100):
    # do_bench reports the median time in milliseconds; convert to microseconds.
    return (
        do_bench(lambda: f(*args), warmup=warmup, rep=rep, return_mode="median") * 1e3
    )


def get_blockwise_problem(
    m: int, n: int, k: int, block_size: int, dtype: torch.dtype, device
):
    assert n % block_size == 0 and k % block_size == 0, (
        "N and K dims must be divisible by block_size"
    )
    assert dtype in [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
    ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
    dtype_max = torch.finfo(dtype).max
    A = (dtype_max * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
    A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=device)
    B = (dtype_max * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
    B_scale = torch.randn(
        (n // block_size, k // block_size), dtype=torch.half, device=device
    )

    return A, A_scale, B, B_scale


def benchmark_latency(
    m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
):
    A_ref = torch.randn((m, k), dtype=torch.half, device=device)
    B_ref = torch.randn((n, k), dtype=torch.half, device=device)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size, dtype, device)
    blockwise_time = benchmark_microseconds(
        blockwise_fp8_gemm, A, A_scale, B, B_scale, block_size
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "block_size": block_size,
        "dtype": dtype,
        "fp16_latency (us)": fp16_time,
        "blockwise_latency (us)": blockwise_time,
        "blockwise_speedup": fp16_time / blockwise_time,
    }


def benchmark_precision(
    m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
):
    lin = torch.nn.Linear(k, n, False, device, torch.half)
    A = torch.randn((m, k), dtype=torch.half, device=device)
    W = lin.weight
    output = A @ W.T

    A_q, A_s = fp8_blockwise_act_quant(A, block_size, dtype)
    W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
    output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s, block_size)

    return {
        "m": m,
        "k": k,
        "n": n,
        "block_size": block_size,
        "dtype": dtype,
        "error_blockwise (dB)": compute_error(output, output_blockwise),
    }


if __name__ == "__main__" and torch.cuda.is_available():
    device = torch.device("cuda")
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)
    block_size_vals = (128, 128, 128, 128)

    latency_results = []
    precision_results = []

    available_dtypes = (
        [torch.float8_e4m3fn, torch.float8_e5m2]
        if is_sm_at_least_89()
        else [torch.float8_e5m2]
    )
    for m in tqdm([1 << i for i in range(14)]):
        for dtype in available_dtypes:
            for n, k, block_size in zip(n_vals, k_vals, block_size_vals):
                latency_results.append(
                    benchmark_latency(m, k, n, block_size, dtype, device)
                )
                precision_results.append(
                    benchmark_precision(m, k, n, block_size, dtype, device)
                )

    df_latency = pd.DataFrame(latency_results)
    df_precision = pd.DataFrame(precision_results)

    df_latency.to_csv("blockwise_triton_latency_results.csv", index=False)
    df_precision.to_csv("blockwise_triton_precision_results.csv", index=False)

    print(df_latency.to_markdown(index=False))
    print(df_precision.to_markdown(index=False))
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from packaging import version

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)
from torchao.utils import is_sm_at_least_89

BLOCKWISE_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_quant_dequant(_, N, K, dtype):
    x = torch.randn(N, K).cuda()
    qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
    x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
    error = torch.norm(x - x_reconstructed) / torch.norm(x)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.1, "Quant-Dequant error is too high"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2]
    if is_sm_at_least_89()
    else [torch.float8_e5m2],
)
def test_blockwise_fp8_gemm(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
    error = torch.norm(C - C_q) / torch.norm(C)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.1, "Quantized GEMM error is too high"
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Blockwise Quantization Implementation

## Overview

This directory contains an implementation of the blockwise quantization scheme introduced by DeepSeek. The method quantizes activations and weight matrices in blocks of 128x1 and 128x128, respectively.

## Quantization Process

### Activation Quantization
- Activations are quantized in blocks of size 128x1 using the FP8 format

### Weight Matrix Quantization
- Weights are quantized in blocks of size 128x128 using the FP8 format

## Kernel Implementation in Triton

- The kernel for blockwise quantization is implemented using Triton
- For now, the only supported dtypes are `torch.float8_e4m3fn` and `torch.float8_e5m2`

## Illustration

![Blockwise Quantization Illustration](https://arxiv.org/html/2412.19437v1/x7.png)

*Illustration of the blockwise quantization process.*

## Original Paper

For detailed motivations and technical specifications, please refer to the original paper:
- [DeepSeek-V3 Technical Report](https://arxiv.org/html/2412.19437v1)
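
As a usage sketch of the API added here (mirroring the unit tests above; assumes a CUDA device, and `torch.float8_e4m3fn` additionally needs SM 8.9+), the quantization helpers compose with the Triton GEMM as follows:

import torch

from torchao.prototype.blockwise_fp8 import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# A: activations (M, K); B: weights (N, K). K and N are multiples of the 128 block size.
A = torch.randn(16, 1024, device="cuda")
B = torch.randn(2048, 1024, device="cuda")

A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)     # 128x1 activation blocks
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)  # 128x128 weight blocks
C = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)  # approximates A @ B.T

The helpers are called with their default block size here, which matches the 128x1 / 128x128 blocking the README describes.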
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from .blockwise_linear import BlockwiseQuantLinear
from .blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)

__all__ = [
    "blockwise_fp8_gemm",
    "BlockwiseQuantLinear",
    "fp8_blockwise_act_quant",
    "fp8_blockwise_weight_quant",
    "fp8_blockwise_weight_dequant",
]
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
)


class BlockwiseQuantLinear(nn.Module):
    """
    Custom linear layer with support for blockwise-quantized weights and an optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        block_size (int): Block size for quantization. Defaults to 128.
        dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        block_size: int = 128,
        dtype: torch.dtype = torch.float8_e4m3fn,
    ):
        super().__init__()
        supported_dtypes = [
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ]
        assert dtype in supported_dtypes, (
            f"Unsupported dtype: {dtype}. Supported dtypes: {supported_dtypes}"
        )
        scale_in_features = (in_features + block_size - 1) // block_size
        scale_out_features = (out_features + block_size - 1) // block_size
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
        # Register the blockwise scales as a parameter and also expose them as an
        # attribute of the weight, so the GEMM can read them from the weight itself.
        self.weight.scale = self.scale = nn.Parameter(
            torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
        )
        self.block_size = block_size
        self.dtype = dtype

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype)
        y = blockwise_fp8_gemm(
            x, scale, self.weight, self.weight.scale, self.block_size
        )

        if self.bias is not None:
            y += self.bias
        return y
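
A hypothetical usage sketch for the layer above (not part of the commit): the module stores the FP8 weight and its blockwise scales, which the caller is expected to populate, e.g. from `fp8_blockwise_weight_quant`; activations are quantized on the fly in `forward`.

import torch

from torchao.prototype.blockwise_fp8 import (
    BlockwiseQuantLinear,
    fp8_blockwise_weight_quant,
)

# Build the layer, then fill in a quantized weight and its 128x128 block scales.
lin = BlockwiseQuantLinear(
    1024, 2048, bias=False, block_size=128, dtype=torch.float8_e4m3fn
).cuda()
W = torch.randn(2048, 1024, device="cuda")
W_q, W_s = fp8_blockwise_weight_quant(W, 128, dtype=torch.float8_e4m3fn)
with torch.no_grad():
    lin.weight.copy_(W_q)
    lin.scale.copy_(W_s)  # also reachable as lin.weight.scale

x = torch.randn(4, 1024, device="cuda")
y = lin(x)  # activations are quantized to 128x1 blocks inside forward()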

0 commit comments