Commit f51a142

Fix:
- condition triton import in gemm
- linting
1 parent 337d264 commit f51a142

5 files changed: +10 -11 lines


benchmarks/benchmark_blockwise_scaled_linear_triton.py

+2 -4

@@ -12,12 +12,10 @@
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
 )
-
 from torchao.quantization.quant_api import (
-    _int8_symm_per_token_reduced_range_quant_cutlass,
     _int4_symm_per_token_quant_cutlass,
+    _int8_symm_per_token_reduced_range_quant_cutlass,
 )
-
 from torchao.utils import is_sm_at_least_89


@@ -44,7 +42,7 @@ def get_blockwise_problem(
     assert dtype in [
         torch.float8_e4m3fn,
         torch.float8_e5m2,
-    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
+    ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
     dtype_max = torch.finfo(dtype).max
     A = (dtype_max * (2 * torch.rand(m, k, device=device) - 1)).to(dtype)
     A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=device)
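
Two mechanical cleanups in this file, both under the "linting" bullet: the quant_api imports are reordered alphabetically (isort-style) with stray blank lines dropped, and the assertion message loses an f prefix that had no placeholders to interpolate; the latter is the pattern flake8/ruff report as F541. A minimal sketch of that rule:

import torch

dtype = torch.float8_e4m3fn

# F541: an f-string with no {placeholders} is just a string with a useless prefix.
# before: f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
# after:   "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
assert dtype in [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"

# The f prefix is warranted only when the message interpolates a value:
assert dtype in [torch.float8_e4m3fn, torch.float8_e5m2], f"unsupported dtype: {dtype}"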

test/prototype/test_blockwise_triton.py

-1
@@ -7,7 +7,6 @@
     fp8_blockwise_weight_dequant,
     fp8_blockwise_weight_quant,
 )
-
 from torchao.utils import is_sm_at_least_89

 BLOCKWISE_SIZE_MNK = [

torchao/prototype/blockwise_fp8/__init__.py

+1 -1

@@ -2,8 +2,8 @@
 from .blockwise_linear import BlockwiseQuantLinear
 from .blockwise_quantization import (
     fp8_blockwise_act_quant,
-    fp8_blockwise_weight_quant,
     fp8_blockwise_weight_dequant,
+    fp8_blockwise_weight_quant,
 )

 __all__ = [

torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py

+5 -3

@@ -1,7 +1,9 @@
 import torch
-import triton
-import triton.language as tl
-from triton import Config
+
+if torch.cuda.is_available():
+    import triton
+    import triton.language as tl
+    from triton import Config

 # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
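
This guarded import is the commit's headline fix: triton is a GPU-only dependency that is often absent or unusable on CPU-only hosts, so an unconditional import could break importing the module at all. Gating it on torch.cuda.is_available() keeps the module importable everywhere. A minimal self-contained sketch of the pattern; the HAS_TRITON flag and cdiv fallback are illustrative additions, not part of this commit:

import torch

# Gate the GPU-only dependency so the module still imports on CPU-only hosts.
if torch.cuda.is_available():
    import triton
    import triton.language as tl
    from triton import Config

    HAS_TRITON = True  # illustrative flag, not in the commit
else:
    HAS_TRITON = False

def cdiv(a: int, b: int) -> int:
    # Pure-Python stand-in for triton.cdiv (ceiling division),
    # usable even when the guard above skipped the import.
    return -(a // -b)

if __name__ == "__main__":
    print("triton available:", HAS_TRITON)
    print(cdiv(1000, 128))  # 8 blocks of 128 cover 1000 elements

One caveat: names guarded this way only exist when the condition held, so the rest of the module must avoid referencing them at import time on CPU-only hosts. A try/except ImportError guard is a common alternative that also covers CUDA machines where triton simply isn't installed.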

torchao/prototype/blockwise_fp8/blockwise_quantization.py

+2 -2

@@ -55,7 +55,7 @@ def fp8_blockwise_act_quant(
     assert dtype in [
         torch.float8_e4m3fn,
         torch.float8_e5m2,
-    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
+    ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
     y = torch.empty_like(x, dtype=dtype)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
@@ -117,7 +117,7 @@ def fp8_blockwise_weight_quant(
     assert dtype in [
         torch.float8_e4m3fn,
         torch.float8_e5m2,
-    ], f"dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
+    ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2"
     M, N = x.size()
     y = torch.empty_like(x, dtype=dtype)
     s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32)
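
For orientation, a usage sketch of the two quantizers these hunks touch. The (tensor, scale) return pairs are inferred from the y/s buffers visible in the diff, and block_size=128 is an assumption borrowed from the DeepSeek-V3 kernel credited in the gemm file, not something this commit states; running it needs a CUDA GPU with triton installed:

import torch
from torchao.prototype.blockwise_fp8 import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

x = torch.randn(512, 2048, device="cuda")   # activations
w = torch.randn(2048, 2048, device="cuda")  # weight matrix

# Assumed signatures/returns: (quantized fp8 tensor, float32 scales).
xq, xs = fp8_blockwise_act_quant(x, block_size=128, dtype=torch.float8_e4m3fn)
wq, ws = fp8_blockwise_weight_quant(w, block_size=128, dtype=torch.float8_e4m3fn)

# Scale shapes read off the new_empty calls in the diff:
# activations keep one scale per 128-wide block of the last dim,
# weights keep one scale per 128x128 tile.
assert xs.shape == (512, 2048 // 128)
assert ws.shape == (2048 // 128, 2048 // 128)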
