Skip to content

Commit 56132a3

Browse files
add tests for round to power of 2
1 parent ab93e18 commit 56132a3

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

test/float8/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
2+
import torch
3+
4+
from torchao.float8.float8_utils import _round_down_to_power_of_2
5+
6+
7+
@pytest.mark.parametrize(
8+
"input_shape",
9+
[
10+
(1,),
11+
(2, 3),
12+
(8, 2048, 4, 1024),
13+
],
14+
)
15+
@pytest.mark.parametrize(
16+
"multiplier",
17+
[
18+
1.0,
19+
2.5,
20+
10.0,
21+
],
22+
)
23+
def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int):
24+
input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier
25+
expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
26+
result = _round_down_to_power_of_2(input_tensor)
27+
assert torch.equal(
28+
result, expected_output
29+
), f"expected {expected_output}, but got {result}"
30+
31+
32+
def test_non_float32_input():
33+
non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
34+
with pytest.raises(AssertionError, match="input must be float32 tensor"):
35+
_round_down_to_power_of_2(non_float32_tensor)

torchao/float8/float8_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ def amax_to_scale(
4949
else:
5050
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
5151
if round_scales_to_power_of_2:
52-
# rounds down to the nearest power of 2
53-
res = res.view(torch.int32)
54-
res = (res >> 23) << 23
55-
res = res.view(torch.float32)
52+
res = _round_down_to_power_of_2(res)
5653
return res
5754

5855

@@ -286,3 +283,12 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
286283
or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
287284
or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
288285
)
286+
287+
288+
def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor:
289+
assert x.dtype == torch.float32, "input must be float32 tensor"
290+
# rounds down to the nearest power of 2
291+
x = x.view(torch.int32)
292+
x = (x >> 23) << 23
293+
x = x.view(torch.float32)
294+
return x

0 commit comments

Comments
 (0)