 import pytest
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    ScalingType,
-    e4m3_dtype,
     recipe_name_to_linear_config,
+    ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -45,11 +45,7 @@
     hp_tensor_to_float8_delayed,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,7 +416,14 @@ def test_sync_amax_func_cuda_graph_success():
         torch.float16,
     ],
 )
-def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+@pytest.mark.parametrize(
+    "power_of_2_scale",
+    [
+        True,
+        False,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
     hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
@@ -456,13 +459,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
@@ -474,8 +479,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config
-    from torch._inductor import metrics
+    from torch._inductor import config as inductor_config, metrics
 
     inductor_config.loop_ordering_after_fusion = True
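
Note on what the new power_of_2_scale parametrization exercises: the flag is only threaded through test_dynamic_scale_numeric_parity here, and the test still requires the eager and compiled paths to agree bit-for-bit via torch.equal on _scale and _data. Assuming the flag asks the dynamic-scaling path to round the computed scale to a power of two, a minimal sketch of that kind of rounding is below; the helper name round_scale_down_to_power_of_2, the amax-based recipe, and the use of e4m3's max value (448.0) are illustrative assumptions, not torchao's implementation.

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Illustrative only: snap each scale down to the nearest power of two by
    # keeping just its exponent; the real torchao kernel may operate on the
    # fp32 bit pattern directly instead.
    return torch.exp2(torch.floor(torch.log2(scale)))

# Hypothetical amax-based dynamic scale for e4m3 (max representable value is 448.0).
hp_tensor = torch.randn(16, 16, dtype=torch.bfloat16)
amax = hp_tensor.abs().max().to(torch.float32)
scale = torch.tensor(448.0) / torch.clamp(amax, min=1e-12)
print(scale.item(), round_scale_down_to_power_of_2(scale).item())

A power-of-two scale changes only the exponent of the values it multiplies, which is one reason eager and compiled runs can be expected to match exactly under both parametrizations.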