@@ -13,9 +13,9 @@
 import pytest
 
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
-    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
-    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingType,
+    e4m3_dtype,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -430,6 +430,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
     hp_tensor2 = hp_tensor1.detach().clone()
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        power_of_2_scale=power_of_2_scale,
     )
     linear_mm_config = LinearMMConfig(
         # output
@@ -459,15 +460,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
-        power_of_2_scale=power_of_2_scale,
+        power_of_2_scale=float8_config.power_of_2_scale,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
-        power_of_2_scale=power_of_2_scale,
+        power_of_2_scale=float8_config.power_of_2_scale,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
@@ -479,7 +480,8 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config, metrics
+    from torch._inductor import config as inductor_config
+    from torch._inductor import metrics
 
     inductor_config.loop_ordering_after_fusion = True
 