Skip to content

Commit a9fe17e

Browse files
fix linter issues
1 parent f2433b1 commit a9fe17e

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

test/float8/test_base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import torch.nn as nn
1616

1717
from torchao.utils import (
18+
TORCH_VERSION_AT_LEAST_2_5,
1819
is_sm_at_least_89,
1920
is_sm_at_least_90,
20-
TORCH_VERSION_AT_LEAST_2_5,
2121
)
2222

2323
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@
2626

2727
from torchao.float8.config import (
2828
CastConfig,
29-
e4m3_dtype,
30-
e5m2_dtype,
3129
Float8LinearConfig,
3230
Float8LinearRecipeName,
33-
recipe_name_to_linear_config,
3431
ScalingGranularity,
3532
ScalingType,
33+
e4m3_dtype,
34+
e5m2_dtype,
35+
recipe_name_to_linear_config,
3636
)
3737
from torchao.float8.float8_linear import Float8Linear
3838
from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
4848
from torchao.float8.float8_tensor import (
4949
Float8Tensor,
5050
GemmInputRole,
51-
hp_tensor_and_scale_to_float8,
5251
LinearMMConfig,
5352
ScaledMMConfig,
53+
hp_tensor_and_scale_to_float8,
5454
)
5555
from torchao.float8.float8_utils import (
56+
FP8_TYPES,
5657
compute_error,
5758
config_has_stateful_scaling,
5859
fp8_tensor_statistics,
59-
FP8_TYPES,
6060
tensor_to_scale,
6161
)
6262
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear

test/float8/test_compile.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
import pytest
1414

1515
from torchao.utils import (
16+
TORCH_VERSION_AT_LEAST_2_5,
1617
is_sm_at_least_89,
1718
is_sm_at_least_90,
18-
TORCH_VERSION_AT_LEAST_2_5,
1919
)
2020

2121
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
2929
from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
3030
from torchao.float8.config import (
3131
CastConfig,
32-
e4m3_dtype,
3332
Float8LinearConfig,
3433
Float8LinearRecipeName,
35-
recipe_name_to_linear_config,
3634
ScalingType,
35+
e4m3_dtype,
36+
recipe_name_to_linear_config,
3737
)
3838
from torchao.float8.float8_linear import Float8Linear
3939
from torchao.float8.float8_linear_utils import (
@@ -430,6 +430,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
430430
hp_tensor2 = hp_tensor1.detach().clone()
431431
float8_config = Float8LinearConfig(
432432
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
433+
power_of_2_scale=power_of_2_scale,
433434
)
434435
linear_mm_config = LinearMMConfig(
435436
# output
@@ -459,15 +460,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
459460
e4m3_dtype,
460461
linear_mm_config,
461462
gemm_input_role=GemmInputRole.WEIGHT,
462-
power_of_2_scale=power_of_2_scale,
463+
power_of_2_scale=float8_config.power_of_2_scale,
463464
)
464465
torch._dynamo.reset()
465466
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
466467
hp_tensor2,
467468
e4m3_dtype,
468469
linear_mm_config,
469470
gemm_input_role=GemmInputRole.WEIGHT,
470-
power_of_2_scale=power_of_2_scale,
471+
power_of_2_scale=float8_config.power_of_2_scale,
471472
)
472473
assert torch.equal(float8_eager._scale, float8_compile._scale)
473474
assert torch.equal(float8_eager._data, float8_compile._data)
@@ -479,7 +480,8 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
479480
)
480481
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
481482
def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
482-
from torch._inductor import config as inductor_config, metrics
483+
from torch._inductor import config as inductor_config
484+
from torch._inductor import metrics
483485

484486
inductor_config.loop_ordering_after_fusion = True
485487

0 commit comments

Comments (0)