@@ -10,11 +10,7 @@
 import torch.distributed as dist
 from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
 
-from torchao.float8.config import (
-    Float8LinearConfig,
-    ScalingGranularity,
-    ScalingType,
-)
+from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -33,21 +29,28 @@
 
 
 @torch.no_grad()
-def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
+def amax_to_scale(
+    amax: torch.Tensor,
+    float8_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
+):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
+        round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
     amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        res = res.to(torch.float32)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-    return res.to(torch.float32)
+    if round_scales_to_power_of_2:
+        res = _round_scale_down_to_power_of_2(res)
+    return res
 
 
 @torch.no_grad()
@@ -119,21 +122,35 @@ def tensor_to_amax(
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor,
+    hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
+    """
+    Compute scaling factor for the given high precision tensor.
+
+    Args:
+        hp_tensor: high precision tensor
+        float8_dtype: the float8 dtype to use
+        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
+    """
     amax = tensor_to_amax(
-        x,
+        hp_tensor,
         reduce_amax,
         device_mesh,
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype)
+    return amax_to_scale(
+        amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2
+    )
 
 
 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
@@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
         or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
         or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
     )
+
+
+def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+    assert scale.dtype == torch.float32, "scale must be float32 tensor"
+    return torch.exp2(torch.floor(torch.log2(scale)))
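
Note on the new option: amax_to_scale computes the scale as finfo(float8_dtype).max / clamp(amax, EPS), and with round_scales_to_power_of_2=True the result is then rounded down to the nearest power of 2 via exp2(floor(log2(scale))). The snippet below is a minimal standalone sketch of that behavior, not part of the diff; the EPS value and the use of torch.float8_e4m3fn are assumptions (the module defines its own EPS constant, and fp8 dtype support depends on the installed PyTorch build).

import torch

EPS = 1e-12  # placeholder for the module-level EPS constant (assumption)

# Scale for an example amax value, mirroring amax_to_scale above
amax = torch.tensor([3.7], dtype=torch.float64)
scale = (torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)).to(
    torch.float32
)

# Round down to the nearest power of 2, as _round_scale_down_to_power_of_2 does
rounded = torch.exp2(torch.floor(torch.log2(scale)))
print(scale.item(), rounded.item())  # ~121.08 -> 64.0

The usual motivation for power-of-2 scales is that multiplying or dividing by them only changes the floating-point exponent, so the scaling itself introduces no additional rounding error.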