@@ -1,35 +1,68 @@
 import pytest
 import torch
 
-from torchao.float8.float8_utils import _round_down_to_power_of_2
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 
 
+# source for notable single-precision cases:
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+#
+# TODO(danielvegamyhre):
+# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23).
+# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
+# 2. add case for "nan"
+# need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf.
+# 3. adjust cases for subnormal values so we aren't clamping the expected results
+# into the normal range.
+# preliminary investigation shows it may not be possible to support all subnormals
+# with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x)))
+# approach.
 @pytest.mark.parametrize(
-    "input_shape",
-    [
-        (1,),
-        (2, 3),
-        (8, 2048, 4, 1024),
-    ],
-)
-@pytest.mark.parametrize(
-    "multiplier",
+    "input",
     [
         1.0,
-        2.5,
-        10.0,
+        float("inf"),
+        # smallest positive subnormal number
+        2**-126 * 2**-23,
+        # largest subnormal number
+        2**-126 * (1 - 2**-23),
+        # smallest positive normal number
+        2**-126,
+        # largest number less than one
+        1.0 - 2**-24,
+        # smallest number larger than one
+        1.0 + 2**-23,
     ],
 )
-def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int):
-    input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier
-    expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    result = _round_down_to_power_of_2(input_tensor)
+def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
+    input_tensor = torch.tensor(input, dtype=torch.float32)
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    # get expected value for comparison
+    # TODO(danielvegamyhre): support subnormal values
+    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
+    smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32)
+    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+
     assert torch.equal(
-        result, expected_output
-    ), f"expected {expected_output}, but got {result}"
+        result, expected_result
+    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
 
 
-def test_non_float32_input():
-    non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
-    with pytest.raises(AssertionError, match="input must be float32 tensor"):
-        _round_down_to_power_of_2(non_float32_tensor)
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)
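Note on the bitshift behavior the TODO comments reference: below is a minimal sketch of a mantissa-clearing round-down, written under the assumption that the bitshift approach keeps only the 8 biased-exponent bits of the float32 representation. The helper name is hypothetical and this is not necessarily the torchao implementation.

import torch

def round_down_bitshift_sketch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch, not the torchao implementation: reinterpret the
    # float32 bit pattern as int32, extract the 8 biased-exponent bits,
    # and rebuild the value with a zero mantissa, i.e. round down to 2**e.
    assert x.dtype == torch.float32, "input must be float32 tensor"
    exponent_bits = (x.view(torch.int32) >> 23) & 0xFF
    return (exponent_bits << 23).view(torch.float32)

Under standard IEEE-754 float32 semantics this sketch reproduces the discrepancies the TODO describes: for nan the mantissa is discarded, so the all-ones exponent reads back as inf; for the largest normal value, float32 rounds log2(x) = 127.9999... up to 128.0 and exp2(128.0) overflows to inf, while the exponent bits still yield the real value 2**127. Subnormals have a zero exponent field, so mantissa-clearing flushes them to 0 rather than to a nearby power of 2, which is consistent with the TODO's note that not all subnormals can be supported by bitshifting.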
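For the subnormal test cases, the clamping of the expected result can be seen directly. A small illustration (assuming IEEE-754 float32 semantics; printed values approximate) of why the exp2(floor(log2(x))) reference alone stays outside the normal range:

import torch

# Smallest positive subnormal: 2**-149. The reference expression stays
# subnormal (~1.4e-45), so the test clamps the expected value up to the
# smallest positive normal, 2**-126 (~1.18e-38).
x = torch.tensor(2**-126 * 2**-23, dtype=torch.float32)
reference = torch.exp2(torch.floor(torch.log2(x)))
expected = torch.max(reference, torch.tensor(2**-126, dtype=torch.float32))
print(reference.item(), expected.item())  # ~1.4e-45 vs ~1.18e-38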