1 file changed: +14 −3 lines changed
@@ -8,7 +8,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 
@@ -120,16 +120,27 @@ def tensor_to_amax(
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor,
+    hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
     power_of_2_scale: bool = False,
 ) -> torch.Tensor:
+    """
+    Compute scaling factor for the given high precision tensor.
+
+    Args:
+        hp_tensor: the tensor to convert
+        float8_dtype: the float8 dtype to use
+        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+        scaling_granularity: defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
+    """
     amax = tensor_to_amax(
-        x,
+        hp_tensor,
         reduce_amax,
         device_mesh,
         scaling_granularity,
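
The docstring added in this diff describes an amax-based scale: take max(abs(hp_tensor)) at the requested granularity and, optionally, round the resulting scale down to a power of 2. The sketch below only illustrates that idea in plain PyTorch and is not torchao's implementation; the helper name `example_tensor_to_scale`, the `scale = finfo(float8_dtype).max / amax` convention, and the small clamp guarding against division by zero are assumptions, and it requires a PyTorch build that ships the float8 dtypes.

```python
from typing import Optional

import torch


def example_tensor_to_scale(
    hp_tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    axiswise_dim: Optional[int] = None,  # None -> tensorwise granularity
    power_of_2_scale: bool = False,
) -> torch.Tensor:
    # amax at the requested granularity
    if axiswise_dim is None:
        amax = hp_tensor.abs().max()  # tensorwise: a single scalar amax
    else:
        amax = hp_tensor.abs().amax(dim=axiswise_dim, keepdim=True)  # axiswise amax
    # (with reduce_amax, a distributed MAX all-reduce of `amax` would go here)
    amax = amax.to(torch.float64).clamp(min=1e-12)  # assumed guard against division by zero
    # assumed convention: the scale maps the observed amax onto the float8 max value
    scale = torch.finfo(float8_dtype).max / amax
    if power_of_2_scale:
        # round the scale down to the nearest power of 2
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale.to(torch.float32)


x = torch.randn(16, 32)
print(example_tensor_to_scale(x, power_of_2_scale=True))   # scalar tensorwise scale, a power of 2
print(example_tensor_to_scale(x, axiswise_dim=-1).shape)   # torch.Size([16, 1]), one scale per row
```

The usual motivation for the power-of-2 option is that multiplying by a power-of-2 scale only shifts the exponent of the scaled values, so the scaling itself introduces no mantissa rounding error and is exactly invertible.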