From c70ad6091e85d91bfca3d9f9a85ee9a1ee940edf Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 14:01:42 -0800 Subject: [PATCH] add docstring to tensor_to_scale --- torchao/float8/float8_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index a410751d0..8661b89b0 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -120,7 +120,7 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, + hp_tensor: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, device_mesh=None, @@ -128,8 +128,19 @@ def tensor_to_scale( axiswise_dim: Optional[int] = None, power_of_2_scale: bool = False, ) -> torch.Tensor: + """ + Compute scaling factor for the given high precision tensor. + + Args: + hp_tensor: high precision tensor + float8_dtype: the float8 dtype to use + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across + power_of_2_scale: if true, round scaling factor down to the nearest power of 2. + """ amax = tensor_to_amax( - x, + hp_tensor, reduce_amax, device_mesh, scaling_granularity,