Skip to content

Commit

Permalink
add docstring to tensor_to_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvegamyhre committed Feb 5, 2025
1 parent 896bd8f commit c70ad60
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,27 @@ def tensor_to_amax(

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor,
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
) -> torch.Tensor:
"""
Compute scaling factor for the given high precision tensor.
Args:
hp_tensor: high precision tensor
float8_dtype: the float8 dtype to use
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
amax = tensor_to_amax(
x,
hp_tensor,
reduce_amax,
device_mesh,
scaling_granularity,
Expand Down

0 comments on commit c70ad60

Please sign in to comment.