
Commit 337d264

Fix:
- fixing W4A8 quantization for the CUTLASS kernel in the precision benchmark
- importing triton only if CUDA is available
- setting a less harsh threshold for quant-dequant and for GEMM kernel mm precision

Parent: 8d68d45

3 files changed: +36 −11 lines

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 22 additions & 7 deletions
@@ -1,7 +1,9 @@
 import pandas as pd
 import torch
 from tqdm import tqdm
-from triton.testing import do_bench
+
+if torch.cuda.is_available():
+    from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
 from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
@@ -10,9 +12,10 @@
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
 )
+
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
-    quantize_,
+    _int8_symm_per_token_reduced_range_quant_cutlass,
+    _int4_symm_per_token_quant_cutlass,
 )
 
 from torchao.utils import is_sm_at_least_89
@@ -38,9 +41,14 @@ def get_blockwise_problem(
     assert (
         n % block_size == 0 and k % block_size == 0
     ), "N and K dims must be divisible by block_size"
-    A = (448.0 * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
+    assert dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
+    dtype_max = torch.finfo(dtype).max
+    A = (dtype_max * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
     A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=device)
-    B = (448.0 * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
+    B = (dtype_max * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
     B_scale = torch.randn(
         (n // block_size, k // block_size), dtype=torch.half, device=device
     )
@@ -89,8 +97,15 @@ def benchmark_precision(
     W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
     output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)
 
-    quantize_(lin, int8_dynamic_activation_int4_weight())
-    output_rowwise = lin(A)
+    qact = _int8_symm_per_token_reduced_range_quant_cutlass(A)
+    qweight = _int4_symm_per_token_quant_cutlass(W)
+    output_rowwise = rowwise_scaled_linear_cutlass_s8s4(
+        qact.tensor_impl.int_data,
+        qact.tensor_impl.scale,
+        qweight.tensor_impl.int_data,
+        qweight.tensor_impl.scale,
+        None,
+    )
 
     return {
         "m": m,

test/prototype/test_blockwise_triton.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
     error = torch.norm(x - x_reconstructed) / torch.norm(x)
     print(f"Relative Error: {error.item():.6f}")
 
-    assert error < 0.05, "Quant-Dequant error is too high"
+    assert error < 0.1, "Quant-Dequant error is too high"
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -56,4 +56,4 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
     error = torch.norm(C - C_q) / torch.norm(C)
     print(f"Relative Error: {error.item():.6f}")
 
-    assert error < 0.05, "Quantize gemm error is too high"
+    assert error < 0.1, "Quantize gemm error is too high"
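
The relaxed 0.1 bound leaves headroom for float8_e5m2: with 2 mantissa bits its worst-case per-element relative rounding error is about 2^-3 = 12.5%, versus about 2^-4 = 6.25% for float8_e4m3fn's 3 mantissa bits, so 0.05 was a tight target for e5m2 inputs. A small sketch of that intuition using a plain fp8 round-trip (not the blockwise kernels), assuming a PyTorch build that supports fp8 dtype conversion:

```python
import torch

# Round-trip a tensor through each fp8 dtype and measure the relative error.
# Values from randn stay well inside both fp8 ranges, so only rounding matters.
x = torch.randn(4096, dtype=torch.float32)
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    x_rt = x.to(dt).to(torch.float32)
    rel_err = torch.norm(x - x_rt) / torch.norm(x)
    print(f"{dt}: relative error = {rel_err.item():.3f}")
```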

torchao/prototype/blockwise_fp8/blockwise_quantization.py

Lines changed: 12 additions & 2 deletions
@@ -1,8 +1,10 @@
 from typing import Tuple
 
 import torch
-import triton
-import triton.language as tl
+
+if torch.cuda.is_available():
+    import triton
+    import triton.language as tl
 
 
 @triton.jit
@@ -50,6 +52,10 @@ def fp8_blockwise_act_quant(
     assert (
         x.size(-1) % block_size == 0
     ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
+    assert dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
     y = torch.empty_like(x, dtype=dtype)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
@@ -108,6 +114,10 @@ def fp8_blockwise_weight_quant(
     assert (
         x.size(0) % block_size == 0 and x.size(1) % block_size == 0
     ), f"Both dimensions of x must be divisible by block_size (block_size={block_size})"
+    assert dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
     M, N = x.size()
     y = torch.empty_like(x, dtype=dtype)
     s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32)
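
The module-level triton import is now guarded on torch.cuda.is_available(), so importing this file on a CPU-only install where triton is missing no longer raises an ImportError. A minimal sketch of the same pattern with an explicit fallback flag; the HAS_TRITON name and the else branch are illustrative additions, not part of the commit:

```python
import torch

# Triton typically ships only with CUDA builds of PyTorch, so skip the import
# on CPU-only machines; callers can check the flag before launching kernels.
if torch.cuda.is_available():
    import triton
    import triton.language as tl

    HAS_TRITON = True
else:
    HAS_TRITON = False
```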
