Commit 40166e1

convert to fp32 before rounding scale down to power of 2; update unit tests
Parent: c434498

File tree

2 files changed: +21 −55 lines

  test/float8/test_float8_utils.py (+18 −35)
  torchao/float8/float8_utils.py (+3 −20)

test/float8/test_float8_utils.py (+18 −35)

@@ -6,47 +6,30 @@
 
 # source for notable single-precision cases:
 # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
-#
-# TODO(danielvegamyhre):
-# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23).
-#    need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
-# 2. add case for "nan"
-#    need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf.
-# 3. adjust cases for subnormal values so we aren't clamping the expected results
-#    into the normal range.
-#    preliminary investigation shows it may not be possible to support all subnormals
-#    with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x)))
-#    approach.
 @pytest.mark.parametrize(
-    "input",
+    "test_case",
     [
-        1.0,
-        float("inf"),
-        # smallest positive subnormal number
-        2**-126 * 2**-23,
-        # largest subnormal number
-        2**-126 * (1 - 2**-23),
-        # smallest positive normal number
-        2**-126,
-        # largest number less than one
-        1.0 - 2**-24,
-        # smallest number larger than one
-        1.0 + 2**-23,
+        # "test_case_name": [input, expected result]
+        ("one", [1.0, 1.0]),
+        ("inf", [float("inf"), float("inf")]),
+        ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]),
+        ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+        ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]),
+        ("smallest positive normal number", [2**-126, 2**-126]),
+        ("largest number less than one", [1.0 - 2**-24, 0.5]),
+        ("smallest number larger than one", [1.0 + 2**-23, 1.0]),
     ],
 )
-def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
-    input_tensor = torch.tensor(input, dtype=torch.float32)
-    result = _round_scale_down_to_power_of_2(input_tensor)
-
-    # get expected value for comparison
-    # TODO(danielvegamyhre): support subnormal values
-    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32)
-    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: dict,
+):
+    test_case_name, (input, expected_result) = test_case
+    input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result)
 
+    result = _round_scale_down_to_power_of_2(input_tensor)
     assert torch.equal(
-        result, expected_result
-    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
+        result, expected_tensor
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
 
 
 @pytest.mark.parametrize(
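Note on the removed TODO: it documents why the bitshift and exp2(floor(log2(x)))
strategies disagree at the edges of the fp32 range. A minimal sketch of that
discrepancy (the helper names and print statements here are illustrative, not
part of the commit):

import torch

def bitshift_round(x: torch.Tensor) -> torch.Tensor:
    # previous approach: reinterpret fp32 bits as int32, clear the 23 mantissa bits
    return ((x.view(torch.int32) >> 23) << 23).view(torch.float32)

def exp2_round(x: torch.Tensor) -> torch.Tensor:
    # approach this commit switches to
    return torch.exp2(torch.floor(torch.log2(x)))

largest_normal = torch.tensor(2**127 * (2 - 2**-23), dtype=torch.float32)
nan = torch.tensor(float("nan"), dtype=torch.float32)

# bitshift keeps the largest normal finite (2**127), while log2 of it rounds
# to 128.0 in fp32, so exp2 overflows to inf
print(bitshift_round(largest_normal).item(), exp2_round(largest_normal).item())

# clearing nan's mantissa leaves an all-ones exponent, i.e. inf, while the
# exp2/log2 path propagates nan
print(bitshift_round(nan).item(), exp2_round(nan).item())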

torchao/float8/float8_utils.py (+3 −20)

@@ -285,23 +285,6 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
     )
 
 
-def _round_scale_down_to_power_of_2(x: torch.Tensor):
-    assert x.dtype == torch.float32, "scale must be float32 tensor"
-
-    # eps = smallest normal fp32 value
-    # TODO(danielvegamyhre): support subnormal values
-    eps = 2**-126
-    x = torch.clamp(
-        x,
-        min=eps,
-    )
-
-    # view as int32 to allow bitshifting
-    x_int = x.view(torch.int32)
-
-    # clear mantissa bits (rightmost 23 bits)
-    x_int = (x_int >> 23) << 23
-
-    # return result as fp32
-    result = x_int.view(torch.float32)
-    return result
+def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+    assert scale.dtype == torch.float32, "scale must be float32 tensor"
+    return torch.exp2(torch.floor(torch.log2(scale)))
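For reference, the new implementation rounds any positive normal fp32 scale
down to the nearest power of two. A quick usage sketch (the function body is
copied from the diff above; the sample scale values are arbitrary, chosen only
to illustrate the rounding):

import torch

def _round_scale_down_to_power_of_2(scale: torch.Tensor):
    assert scale.dtype == torch.float32, "scale must be float32 tensor"
    return torch.exp2(torch.floor(torch.log2(scale)))

scales = torch.tensor([1.0, 3.7, 0.3, 2**-126], dtype=torch.float32)
# 3.7 -> 2.0 (2**1), 0.3 -> 0.25 (2**-2); exact powers of two pass through
print(_round_scale_down_to_power_of_2(scales))

Because the result has an empty mantissa, multiplying or dividing a tensor by
such a scale only adjusts exponent bits (barring overflow/underflow), which is
the property the removed bitshift version achieved by clearing the mantissa
directly.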
