1 file changed: +14 −3 lines changed

@@ -8,7 +8,7 @@

 import torch
 import torch.distributed as dist
-from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
@@ -120,16 +120,27 @@ def tensor_to_amax(

 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor,
+    hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
     power_of_2_scale: bool = False,
 ) -> torch.Tensor:
+    """
+    Compute scaling factor for the given high precision tensor.
+
+    Args:
+        hp_tensor: the tensor to convert
+        float8_dtype: the float8 dtype to use
+        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
+    """
     amax = tensor_to_amax(
-        x,
+        hp_tensor,
         reduce_amax,
         device_mesh,
         scaling_granularity,
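
For context, here is a minimal sketch of how the renamed signature might be exercised. The import path (torchao.float8.float8_utils) and the AXISWISE enum value are assumptions based on the modules this diff touches, not something the diff itself confirms.

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale  # assumed module path

hp_tensor = torch.randn(64, 128, dtype=torch.bfloat16)

# Tensorwise scaling: a single scale derived from max(abs(hp_tensor)).
scale = tensor_to_scale(hp_tensor, torch.float8_e4m3fn)

# Axiswise scaling: one scale per row, with amax reduced along the last dim.
row_scales = tensor_to_scale(
    hp_tensor,
    torch.float8_e4m3fn,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)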
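
The power_of_2_scale flag documented above rounds the scale down to the nearest power of 2. Below is a sketch of one common way to implement that rounding; the helper name is hypothetical, and the PR's actual implementation may differ (e.g. via exponent-bit manipulation):

import torch

def round_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: 2 ** floor(log2(scale)) keeps only the exponent,
    # so multiplying by the resulting scale introduces no mantissa rounding.
    return torch.exp2(torch.floor(torch.log2(scale)))

print(round_down_to_power_of_2(torch.tensor([0.75, 3.2, 1.0])))
# tensor([0.5000, 2.0000, 1.0000])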