Skip to content

Commit 4169927

Browse files
add unit tests for rounding scale down to nearest power of 2
1 parent 56132a3 commit 4169927

File tree

2 files changed

+76
-30
lines changed

2 files changed

+76
-30
lines changed

test/float8/test_utils.py

+55-22
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,68 @@
11
import pytest
22
import torch
33

4-
from torchao.float8.float8_utils import _round_down_to_power_of_2
4+
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
55

66

7+
# source for notable single-precision cases:
8+
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
9+
#
10+
# TODO(danielvegamyhre):
11+
# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23).
12+
# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
13+
# 2. add case for "nan"
14+
# need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf.
15+
# 3. adjust cases for subnormal values so we aren't clamping the expected results
16+
# into the normal range.
17+
# preliminary investigation shows it may not be possible to support all subnormals
18+
# with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x)))
19+
# approach.
720
@pytest.mark.parametrize(
    # renamed from `input` to avoid shadowing the builtin
    "scale",
    [
        1.0,
        float("inf"),
        # smallest positive subnormal number
        2**-126 * 2**-23,
        # largest subnormal number
        2**-126 * (1 - 2**-23),
        # smallest positive normal number
        2**-126,
        # largest number less than one
        1.0 - 2**-24,
        # smallest number larger than one
        1.0 + 2**-23,
    ],
)
def test_round_scale_down_to_power_of_2_valid_inputs(scale: float):
    """Each scale must round down to exp2(floor(log2(scale))), clamped to the fp32 normal range."""
    input_tensor = torch.tensor(scale, dtype=torch.float32)
    result = _round_scale_down_to_power_of_2(input_tensor)

    # get expected value for comparison; clamp to the smallest normal fp32
    # value because subnormal results are not yet supported
    # TODO(danielvegamyhre): support subnormal values
    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
    smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32)
    expected_result = torch.max(expected_result, smallest_normal_fp32_value)

    assert torch.equal(
        result, expected_result
    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
3050

3151

32-
@pytest.mark.parametrize(
    "invalid_dtype",
    [
        torch.bfloat16,
        torch.float16,
        torch.float64,
        torch.int8,
        torch.uint8,
        torch.int32,
        torch.uint32,
        torch.int64,
    ],
)
def test_non_float32_input(invalid_dtype: torch.dtype):
    """Any dtype other than float32 must be rejected with an AssertionError."""
    bad_tensor = torch.tensor([3.0], dtype=invalid_dtype)
    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
        _round_scale_down_to_power_of_2(bad_tensor)

torchao/float8/float8_utils.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def amax_to_scale(
4949
else:
5050
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
5151
if round_scales_to_power_of_2:
52-
res = _round_down_to_power_of_2(res)
52+
res = _round_scale_down_to_power_of_2(res)
5353
return res
5454

5555

@@ -285,10 +285,23 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
285285
)
286286

287287

288-
def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
289-
assert x.dtype == torch.float32, "input must be float32 tensor"
290-
# rounds down to the nearest power of 2
291-
x = x.view(torch.int32)
292-
x = (x >> 23) << 23
293-
x = x.view(torch.float32)
294-
return x
288+
def _round_scale_down_to_power_of_2(x: torch.Tensor):
289+
assert x.dtype == torch.float32, "scale must be float32 tensor"
290+
291+
# eps = smallest normal fp32 value
292+
# TODO(danielvegamyhre): support subnormal values
293+
eps = 2**-126
294+
x = torch.clamp(
295+
x,
296+
min=eps,
297+
)
298+
299+
# view as int32 to allow bitshifting
300+
x_int = x.view(torch.int32)
301+
302+
# clear mantissa bits (rightmost 23 bits)
303+
x_int = (x_int >> 23) << 23
304+
305+
# return result as fp32
306+
result = x_int.view(torch.float32)
307+
return result

0 commit comments

Comments
 (0)