 import pandas as pd
 import torch
 from tqdm import tqdm
-from triton.testing import do_bench
+
+if torch.cuda.is_available():
+    from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
 from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
 )
+
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
-    quantize_,
+    _int8_symm_per_token_reduced_range_quant_cutlass,
+    _int4_symm_per_token_quant_cutlass,
 )
 
 from torchao.utils import is_sm_at_least_89
@@ -38,9 +41,14 @@ def get_blockwise_problem(
     assert (
         n % block_size == 0 and k % block_size == 0
     ), "N and K dims must be divisible by block_size"
-    A = (448.0 * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
+    assert dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
+    dtype_max = torch.finfo(dtype).max
+    A = (dtype_max * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
     A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=device)
-    B = (448.0 * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
+    B = (dtype_max * (2 * torch.rand(n, k, device=device) - 1)).to(dtype)
     B_scale = torch.randn(
         (n // block_size, k // block_size), dtype=torch.half, device=device
     )
@@ -89,8 +97,15 @@ def benchmark_precision(
     W_q, W_s = fp8_blockwise_weight_quant(W, block_size, dtype)
     output_blockwise = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)
 
-    quantize_(lin, int8_dynamic_activation_int4_weight())
-    output_rowwise = lin(A)
+    qact = _int8_symm_per_token_reduced_range_quant_cutlass(A)
+    qweight = _int4_symm_per_token_quant_cutlass(W)
+    output_rowwise = rowwise_scaled_linear_cutlass_s8s4(
+        qact.tensor_impl.int_data,
+        qact.tensor_impl.scale,
+        qweight.tensor_impl.int_data,
+        qweight.tensor_impl.scale,
+        None,
+    )
 
     return {
         "m": m,