From 92872d955aecf14a33ed7c4f8076782af6e97b67 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 6 Feb 2025 08:26:49 -0800
Subject: [PATCH] add tests for round to power of 2

---
 test/float8/test_utils.py      | 26 ++++++++++++++++++++++++++
 torchao/float8/float8_utils.py | 14 ++++++++++----
 2 files changed, 36 insertions(+), 4 deletions(-)
 create mode 100644 test/float8/test_utils.py

diff --git a/test/float8/test_utils.py b/test/float8/test_utils.py
new file mode 100644
index 0000000000..f8c3b6c170
--- /dev/null
+++ b/test/float8/test_utils.py
@@ -0,0 +1,26 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import _round_down_to_power_of_2
+
+
+@pytest.mark.parametrize(
+    "input_tensor",
+    [
+        # uniform(0,1) in shape (2,3) * rand_int in [0,1024]
+        torch.rand(2, 3, dtype=torch.float32) * torch.randint(0, 1024, (1,))
+        for _ in range(10)
+    ],
+)
+def test_round_down_to_power_of_2(input_tensor):
+    expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
+    result = _round_down_to_power_of_2(input_tensor)
+    assert torch.allclose(
+        result, expected_output
+    ), f"expected {expected_output}, but got {result}"
+
+
+def test_non_float32_input():
+    non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
+    with pytest.raises(AssertionError, match="input must be float32 tensor"):
+        _round_down_to_power_of_2(non_float32_tensor)
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index bdb08bbb01..a7002516b8 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -49,10 +49,7 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     if round_scales_to_power_of_2:
-        # rounds down to the nearest power of 2
-        res = res.view(torch.int32)
-        res = (res >> 23) << 23
-        res = res.view(torch.float32)
+        res = _round_down_to_power_of_2(res)
     return res
@@ -286,3 +283,12 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
         or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
         or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
     )
+
+
+def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
+    assert x.dtype == torch.float32, "input must be float32 tensor"
+    # rounds down to the nearest power of 2
+    x = x.view(torch.int32)
+    x = (x >> 23) << 23
+    x = x.view(torch.float32)
+    return x
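
Note (not part of the patch): a minimal standalone sketch of the bit trick the refactored
_round_down_to_power_of_2 relies on, and why it agrees with the exp2(floor(log2(x)))
reference used in the new test. The helper names below are illustrative, not from torchao;
only PyTorch is assumed.

    import torch

    def round_down_reference(x: torch.Tensor) -> torch.Tensor:
        # Reference used by the test: 2 ** floor(log2(x)) elementwise.
        return torch.exp2(torch.floor(torch.log2(x)))

    def round_down_bits(x: torch.Tensor) -> torch.Tensor:
        # Same idea as the patched helper: reinterpret the float32 bits as
        # int32, then clear the 23 mantissa bits so only the sign and
        # exponent remain. For positive finite inputs this yields the
        # largest power of 2 <= x. float32-only, hence the assert.
        assert x.dtype == torch.float32, "input must be float32 tensor"
        return ((x.view(torch.int32) >> 23) << 23).view(torch.float32)

    if __name__ == "__main__":
        x = torch.rand(4, dtype=torch.float32) * 100 + 1e-3
        print(x)
        print(round_down_bits(x))
        assert torch.allclose(round_down_bits(x), round_down_reference(x))

Because the scale is positive, zeroing the mantissa cannot overflow or change sign; the
only behavior change in the patch itself is moving the existing three lines into a named,
testable helper.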