Skip to content

Commit 067db27

Browse files
add docstring to tensor_to_scale
1 parent 896bd8f commit 067db27

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

torchao/float8/float8_utils.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010
import torch.distributed as dist
11-
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
11+
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
1212

1313
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
1414

@@ -120,16 +120,27 @@ def tensor_to_amax(
120120

121121
@torch.no_grad()
122122
def tensor_to_scale(
123-
x: torch.Tensor,
123+
hp_tensor: torch.Tensor,
124124
float8_dtype: torch.dtype,
125125
reduce_amax: bool = False,
126126
device_mesh=None,
127127
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
128128
axiswise_dim: Optional[int] = None,
129129
power_of_2_scale: bool = False,
130130
) -> torch.Tensor:
131+
"""
132+
Compute scaling factor for the given high precision tensor.
133+
134+
Args:
135+
hp_tensor: the high precision tensor to compute a scaling factor for
136+
float8_dtype: the float8 dtype to use
137+
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
    device_mesh: the device mesh to reduce amax across, if reduce_amax is true
138+
scaling_granularity: the granularity to use when computing the scaling factor (e.g. tensorwise or axiswise)
139+
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
140+
power_of_2_scale: if true, round the scaling factor down to the nearest power of 2.
141+
"""
131142
amax = tensor_to_amax(
132-
x,
143+
hp_tensor,
133144
reduce_amax,
134145
device_mesh,
135146
scaling_granularity,

0 commit comments

Comments
 (0)