Skip to content

Commit a52590f

Browse files
committed
lint
1 parent 64c58b2 commit a52590f

File tree

7 files changed

+11
-7
lines changed

7 files changed

+11
-7
lines changed

test/float8/test_base.py

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
2626

2727

28-
2928
from torchao.float8.config import (
3029
CastConfig,
3130
Float8LinearConfig,

test/integration/test_integration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@
8080
benchmark_model,
8181
is_fbcode,
8282
is_sm_at_least_90,
83-
unwrap_tensor_subclass,
8483
skip_if_rocm,
84+
unwrap_tensor_subclass,
8585
)
8686

8787
try:

test/kernel/test_galore_downproj.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
1313
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
1414
from torchao.utils import skip_if_rocm
15+
1516
torch.manual_seed(0)
1617

1718
matmul_tuner_topk(10)

test/prototype/test_awq.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import torch
66

77
from torchao.quantization import quantize_
8-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm
8+
from torchao.utils import (
9+
TORCH_VERSION_AT_LEAST_2_3,
10+
TORCH_VERSION_AT_LEAST_2_5,
11+
skip_if_rocm,
12+
)
913

1014
if TORCH_VERSION_AT_LEAST_2_3:
1115
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
1216

1317

14-
1518
class ToyLinearModel(torch.nn.Module):
1619
def __init__(self, m=512, n=256, k=128):
1720
super().__init__()

test/quantization/test_galore_quant.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
triton_quantize_blockwise,
2020
)
2121
from torchao.utils import skip_if_rocm
22+
2223
SEED = 0
2324
torch.manual_seed(SEED)
2425

test/test_utils.py

-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
77

88

9-
10-
11-
129
class TestTorchVersionAtLeast(unittest.TestCase):
1310
def test_torch_version_at_least(self):
1411
test_cases = [

torchao/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from math import gcd
88
from typing import Any, Callable, Tuple
99

10+
import pytest
1011
import torch
1112
import torch.nn.utils.parametrize as parametrize
1213

@@ -160,6 +161,7 @@ def wrapper(*args, **kwargs):
160161

161162
return decorator
162163

164+
163165
def skip_if_rocm(message=None):
164166
"""Decorator to skip tests on ROCm platform with custom message.
165167
@@ -186,6 +188,7 @@ def wrapper(*args, **kwargs):
186188
return decorator(func)
187189
return decorator
188190

191+
189192
def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
190193
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
191194
torch.abs(output_ref)

0 commit comments

Comments
 (0)