Commit

add unit tests for rounding scale down to nearest power of 2
danielvegamyhre committed Feb 6, 2025
1 parent 56132a3 commit 5e0a199
Showing 2 changed files with 80 additions and 28 deletions.
79 changes: 59 additions & 20 deletions test/float8/test_utils.py
@@ -1,35 +1,74 @@
 import pytest
 import torch
 
-from torchao.float8.float8_utils import _round_down_to_power_of_2
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 
 
+# source for notable single-precision cases:
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+#
+# TODO(danielvegamyhre): add case for largest normal fp32 value: 2**127 * (2 - 2**-23)
+# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
 @pytest.mark.parametrize(
-    "input_shape",
+    "input",
     [
-        (1,),
-        (2, 3),
-        (8, 2048, 4, 1024),
+        1.0,
+        # smallest positive subnormal number
+        2**-126 * 2**-23,
+        # largest subnormal number
+        2**-126 * (1 - 2**-23),
+        # smallest positive normal number
+        2**-126,
+        # largest number less than one
+        1.0 - 2**-24,
+        # smallest number larger than one
+        1.0 + 2**-23,
+        torch.tensor([float("inf")]),
     ],
 )
+def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
+    input_tensor = torch.tensor(input, dtype=torch.float32)
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    # get expected value for comparison
+    # TODO(danielvegamyhre): support subnormal values
+    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
+    smallest_normal_fp32_value = torch.tensor([2**-126], dtype=torch.float32)
+    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+
+    assert torch.equal(
+        result, expected_result
+    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
+
+
 @pytest.mark.parametrize(
-    "multiplier",
+    "invalid_input",
     [
-        1.0,
-        2.5,
-        10.0,
+        torch.tensor([0.0]),
+        torch.tensor([-1.0]),
+        torch.tensor([float("nan")]),
+        torch.tensor([-float("inf")]),
     ],
 )
-def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int):
-    input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier
-    expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    result = _round_down_to_power_of_2(input_tensor)
-    assert torch.equal(
-        result, expected_output
-    ), f"expected {expected_output}, but got {result}"
+def test_round_scale_down_to_power_of_2_invalid_inputs(invalid_input: torch.Tensor):
+    with pytest.raises(AssertionError, match="scale must be positive fp32 value"):
+        _round_scale_down_to_power_of_2(invalid_input)
 
 
-def test_non_float32_input():
-    non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
-    with pytest.raises(AssertionError, match="input must be float32 tensor"):
-        _round_down_to_power_of_2(non_float32_tensor)
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)
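
Note: the TODO at the top of the test file can be reproduced without torchao; a minimal sketch in plain torch of the discrepancy it describes for the largest normal fp32 value:

import torch

# largest normal fp32 value: 2**127 * (2 - 2**-23) ~= 3.4028e38
x = torch.tensor([2**127 * (2 - 2**-23)], dtype=torch.float32)

# log2(x) is fractionally below 128 but rounds to exactly 128.0 in fp32,
# so the exp2(floor(log2(x))) reference path overflows to inf
print(torch.exp2(torch.floor(torch.log2(x))))  # tensor([inf])

# the bitshift path only clears the 23 mantissa bits and leaves the
# 8 exponent bits untouched, so it returns the finite value 2**127
print(((x.view(torch.int32) >> 23) << 23).view(torch.float32))  # tensor([1.7014e+38])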
29 changes: 21 additions & 8 deletions torchao/float8/float8_utils.py
@@ -49,7 +49,7 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     if round_scales_to_power_of_2:
-        res = _round_down_to_power_of_2(res)
+        res = _round_scale_down_to_power_of_2(res)
     return res
 
 
@@ -285,10 +285,23 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
     )
 
 
-def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
-    assert x.dtype == torch.float32, "input must be float32 tensor"
-    # rounds down to the nearest power of 2
-    x = x.view(torch.int32)
-    x = (x >> 23) << 23
-    x = x.view(torch.float32)
-    return x
+def _round_scale_down_to_power_of_2(x: torch.Tensor):
+    assert x.dtype == torch.float32, "scale must be float32 tensor"
+    assert torch.all(x > 0), "scale must be positive fp32 value"
+
+    # eps = smallest normal fp32 value
+    eps = torch.tensor([2**-126])
+    x = torch.clamp(
+        x,
+        min=eps,
+    )
+
+    # view as int32 to allow bitshifting
+    x_int = x.view(torch.int32)
+
+    # clear mantissa bits (rightmost 23 bits)
+    x_int = (x_int >> 23) << 23
+
+    # return result as fp32
+    result = x_int.view(torch.float32)
+    return result
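
Note: a minimal usage sketch of the new helper, assuming a torchao checkout that includes this commit (the sample scale values are illustrative only). An fp32 value is 1 sign bit, 8 exponent bits, and 23 mantissa bits, so clearing the mantissa rounds the magnitude down to the nearest power of 2:

import torch

from torchao.float8.float8_utils import _round_scale_down_to_power_of_2

# illustrative positive fp32 scales, including one subnormal value
scales = torch.tensor([0.3, 1.0, 5.5, 2**-140], dtype=torch.float32)

# each scale rounds down to the nearest power of 2; the subnormal is
# first clamped up to the smallest normal fp32 value, 2**-126
print(_round_scale_down_to_power_of_2(scales))
# tensor([2.5000e-01, 1.0000e+00, 4.0000e+00, 1.1755e-38])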
