import torch
- from torchao.float8.config import ScalingGranularity
+ from torchao.float8.config import Float8ScalingFactorConfig, ScalingGranularity
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import (
    Float8Tensor,
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
    device_mesh=None,
    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
    axiswise_dim: Optional[int] = None,
+     scaling_factor_config: Float8ScalingFactorConfig = None,
) -> Float8Tensor:
    """
    Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,10 @@ def hp_tensor_to_float8_dynamic(
            the 3 fwd/bwd gemms of linear
        scaling_granularity: Defines the scaling granularity
        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+         scaling_factor_config: optional configuration used to calculate the scaling factor.
+             * for row-wise scaling granularity, a power of 2 scaling factor is used by default,
+               but can be disabled via this config.
+             * for all other scaling granularities, power of 2 scaling factors are not used by default.
    """
    scale = tensor_to_scale(
        hp_tensor,
@@ -60,6 +65,11 @@ def hp_tensor_to_float8_dynamic(
        scaling_granularity,
        axiswise_dim,
    )
+
+     if _use_power_of_2_scale(scaling_granularity, scaling_factor_config):
+         # this rounds the scaling factor down to the nearest power of 2.
+         scale = torch.exp2(torch.floor(torch.log2(scale)))
+
    return hp_tensor_and_scale_to_float8(
        hp_tensor,
        scale,
@@ -70,6 +80,36 @@ def hp_tensor_to_float8_dynamic(
    )


+ def _use_power_of_2_scale(
+     scaling_granularity: ScalingGranularity,
+     scaling_factor_config: Float8ScalingFactorConfig = None,
+ ) -> bool:
+     """
+     Returns a boolean indicating whether the scaling factor should be rounded
+     down to the nearest power of 2.
+
+     Returns True in these cases:
+     1. The caller has explicitly enabled it in the scaling factor config.
+     2. Row-wise (axiswise) scaling is used and the caller has not explicitly
+        disabled it in the scaling factor config.
+
+     Otherwise, returns False.
+     """
+     power_of_2_scale_enabled = (
+         scaling_factor_config is not None
+         and scaling_factor_config.power_of_2_scale is True
+     )
+     power_of_2_scale_explicitly_disabled = (
+         scaling_factor_config is not None
+         and scaling_factor_config.power_of_2_scale is False
+     )
+     use_power_of_2_scale = power_of_2_scale_enabled or (
+         scaling_granularity == ScalingGranularity.AXISWISE
+         and not power_of_2_scale_explicitly_disabled
+     )
+     return use_power_of_2_scale
+
+
def hp_tensor_to_float8_delayed(
    hp_tensor: torch.Tensor,
    s: torch.Tensor,
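For context on the `scale = torch.exp2(torch.floor(torch.log2(scale)))` line added above: flooring the base-2 log and re-exponentiating rounds each (positive) scale down to the nearest power of 2, i.e. a float with no mantissa bits set, so multiplying by the scale and later dividing it back out changes only the exponent and stays exact (barring under/overflow). A minimal standalone sketch with illustrative values, not taken from this PR:

import torch

scale = torch.tensor([0.3, 3.5, 100.0])
# Round each scale down to the nearest power of 2.
pow2_scale = torch.exp2(torch.floor(torch.log2(scale)))
print(pow2_scale)  # 0.25, 2.0, 64.0

# A power-of-2 scale round-trips exactly; an arbitrary scale generally does not.
x = torch.rand(1024)
assert torch.equal(x * 0.25 / 0.25, x)
print(torch.equal(x * 0.3 / 0.3, x))  # usually False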
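A hedged sketch of how the new argument and helper might be exercised, assuming `Float8ScalingFactorConfig` accepts `power_of_2_scale` as a constructor keyword and that `_use_power_of_2_scale` lives in `torchao.float8.float8_scaling_utils` next to `hp_tensor_to_float8_dynamic`; neither detail is confirmed by this diff:

from torchao.float8.config import Float8ScalingFactorConfig, ScalingGranularity
from torchao.float8.float8_scaling_utils import _use_power_of_2_scale

# With no config, power-of-2 scales are on only for axiswise (row-wise) scaling.
assert _use_power_of_2_scale(ScalingGranularity.AXISWISE, None)
assert not _use_power_of_2_scale(ScalingGranularity.TENSORWISE, None)

# An explicit config overrides the default in either direction.
cfg_off = Float8ScalingFactorConfig(power_of_2_scale=False)  # assumed constructor kwarg
assert not _use_power_of_2_scale(ScalingGranularity.AXISWISE, cfg_off)
cfg_on = Float8ScalingFactorConfig(power_of_2_scale=True)
assert _use_power_of_2_scale(ScalingGranularity.TENSORWISE, cfg_on)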