Skip to content

Commit 8d68d45

Browse files
Fix:
- removing triton dependency - cleaning adaptive dtype
1 parent 0aad4c3 commit 8d68d45

File tree

4 files changed

+27
-32
lines changed

4 files changed

+27
-32
lines changed

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
quantize_,
1616
)
1717

18+
from torchao.utils import is_sm_at_least_89
19+
1820

1921
def benchmark_microseconds(f, *args):
2022
return do_bench(lambda: f(*args), return_mode="median") * 1e3
@@ -101,25 +103,7 @@ def benchmark_precision(
101103
}
102104

103105

104-
def get_device_available_dtypes():
105-
sm = torch.cuda.get_device_capability()
106-
available_dtypes = []
107-
108-
if sm[0] == 8 and sm[1] == 0: # A100
109-
available_dtypes.append(torch.float8_e5m2)
110-
elif sm[0] == 9 and sm[1] == 0: # H100
111-
available_dtypes.append(torch.float8_e5m2)
112-
elif sm[0] == 8 and sm[1] == 9: # L4
113-
available_dtypes.append(torch.float8_e4m3fn)
114-
available_dtypes.append(torch.float8_e5m2)
115-
116-
print(
117-
f"Available data types for device with compute capability {sm}: {available_dtypes}"
118-
)
119-
return available_dtypes
120-
121-
122-
if __name__ == "__main__":
106+
if __name__ == "__main__" and torch.cuda.is_available():
123107
device = torch.device("cuda")
124108
k_vals = (8192, 8192, 8192, 28672)
125109
n_vals = (8192, 10240, 57344, 8192)
@@ -128,7 +112,11 @@ def get_device_available_dtypes():
128112
latency_results = []
129113
precision_results = []
130114

131-
available_dtypes = get_device_available_dtypes()
115+
available_dtypes = (
116+
[torch.float8_e4m3fn, torch.float8_e5m2]
117+
if is_sm_at_least_89()
118+
else [torch.float8_e5m2]
119+
)
132120

133121
for m in tqdm([1 << i for i in range(10)]):
134122
for dtype in available_dtypes:

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ tabulate # QOL for printing tables to stdout
1717
tiktoken
1818
blobfile
1919
lm_eval
20-
triton
2120
# sam
2221
diskcache
2322
pycocotools

test/prototype/test_blockwise_triton.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
fp8_blockwise_weight_quant,
99
)
1010

11+
from torchao.utils import is_sm_at_least_89
12+
1113
BLOCKWISE_SIZE_MNK = [
1214
(2, 512, 128),
1315
(3, 2048, 2048),
@@ -20,9 +22,15 @@
2022

2123
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
2224
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
23-
def test_blockwise_quant_dequant(_, N, K):
25+
@pytest.mark.parametrize(
26+
"dtype",
27+
[torch.float8_e4m3fn, torch.float8_e5m2]
28+
if is_sm_at_least_89()
29+
else [torch.float8_e5m2],
30+
)
31+
def test_blockwise_quant_dequant(_, N, K, dtype):
2432
x = torch.randn(N, K).cuda()
25-
qx, s = fp8_blockwise_weight_quant(x)
33+
qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
2634
x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
2735
error = torch.norm(x - x_reconstructed) / torch.norm(x)
2836
print(f"Relative Error: {error.item():.6f}")
@@ -32,17 +40,19 @@ def test_blockwise_quant_dequant(_, N, K):
3240

3341
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
3442
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
35-
def test_blockwise_fp8_gemm(M, N, K):
43+
@pytest.mark.parametrize(
44+
"dtype",
45+
[torch.float8_e4m3fn, torch.float8_e5m2]
46+
if is_sm_at_least_89()
47+
else [torch.float8_e5m2],
48+
)
49+
def test_blockwise_fp8_gemm(M, N, K, dtype):
3650
A = torch.randn(M, K).cuda()
3751
B = torch.randn(N, K).cuda()
38-
3952
C = A @ B.T
40-
41-
A_q, A_s = fp8_blockwise_act_quant(A)
42-
B_q, B_s = fp8_blockwise_weight_quant(B)
43-
53+
A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
54+
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
4455
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
45-
print(C_q, C)
4656
error = torch.norm(C - C_q) / torch.norm(C)
4757
print(f"Relative Error: {error.item():.6f}")
4858

torchao/prototype/blockwise_fp8/blockwise_linear.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def __init__(
3232
super().__init__()
3333
supported_dtypes = [
3434
torch.float8_e4m3fn,
35-
torch.float8_e4m3fnuz,
3635
torch.float8_e5m2,
37-
torch.float8_e5m2fnuz,
3836
]
3937
assert (
4038
dtype in supported_dtypes

0 commit comments

Comments
 (0)