 import pytest
 import torch

-from torchao.float8.float8_utils import _round_down_to_power_of_2
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2


+# source for notable single-precision cases:
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+#
+# TODO(danielvegamyhre): add case for largest normal fp32 value: 2**127 * (2 - 2**-23)
+# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value.
 @pytest.mark.parametrize(
-    "input_shape",
+    "input",
     [
-        (1,),
-        (2, 3),
-        (8, 2048, 4, 1024),
+        1.0,
+        # smallest positive subnormal number
+        2**-126 * 2**-23,
+        # largest subnormal number
+        2**-126 * (1 - 2**-23),
+        # smallest positive normal number
+        2**-126,
+        # largest number less than one
+        1.0 - 2**-24,
+        # smallest number larger than one
+        1.0 + 2**-23,
+        torch.tensor([float("inf")]),
     ],
 )
+def test_round_scale_down_to_power_of_2_valid_inputs(input: float):
+    input_tensor = torch.tensor(input, dtype=torch.float32)
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    # get expected value for comparison
+    # TODO(danielvegamyhre): support subnormal values
+    expected_result = torch.exp2(torch.floor(torch.log2(input_tensor)))
+    smallest_normal_fp32_value = torch.tensor([2**-126], dtype=torch.float32)
+    expected_result = torch.max(expected_result, smallest_normal_fp32_value)
+
+    assert torch.equal(
+        result, expected_result
+    ), f"input: {input_tensor}, expected {expected_result}, but got {result}"
+
+
 @pytest.mark.parametrize(
-    "multiplier",
+    "invalid_input",
     [
-        1.0,
-        2.5,
-        10.0,
+        torch.tensor([0.0]),
+        torch.tensor([-1.0]),
+        torch.tensor([float("nan")]),
+        torch.tensor([-float("inf")]),
     ],
 )
-def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int):
-    input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier
-    expected_output = torch.exp2(torch.floor(torch.log2(input_tensor)))
-    result = _round_down_to_power_of_2(input_tensor)
-    assert torch.equal(
-        result, expected_output
-    ), f"expected {expected_output}, but got {result}"
+def test_round_scale_down_to_power_of_2_invalid_inputs(invalid_input: torch.Tensor):
+    with pytest.raises(AssertionError, match="scale must be positive fp32 value"):
+        _round_scale_down_to_power_of_2(invalid_input)


-def test_non_float32_input():
-    non_float32_tensor = torch.tensor([3.0], dtype=torch.float64)
-    with pytest.raises(AssertionError, match="input must be float32 tensor"):
-        _round_down_to_power_of_2(non_float32_tensor)
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)
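A note on the TODO above about the largest normal fp32 value: exp2(floor(log2(x))) overflows there because log2 of 2**127 * (2 - 2**-23) is just below 128 and rounds to exactly 128.0 in fp32, so exp2(128.0) is inf, while an exponent-bit manipulation never changes the exponent field. Below is a minimal sketch of that idea, assuming the "bitshift" referenced in the TODO means clearing the 23 mantissa bits; the function name _round_down_to_power_of_2_bits is hypothetical and this is not the actual torchao implementation:

    import torch

    def _round_down_to_power_of_2_bits(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical sketch: reinterpret the fp32 bits as int32, shift the
        # 23 mantissa bits out and back in, keeping only sign and exponent.
        # For positive normal values this yields 2**floor(log2(x)) exactly.
        assert x.dtype == torch.float32, "sketch assumes fp32 input"
        bits = x.view(torch.int32)       # bitcast, no copy
        bits = (bits >> 23) << 23        # zero out the mantissa bits
        return bits.view(torch.float32)  # bitcast back to fp32

    # largest normal fp32 value, from the TODO
    x = torch.tensor([2**127 * (2 - 2**-23)], dtype=torch.float32)
    print(_round_down_to_power_of_2_bits(x))       # tensor([1.7014e+38]), i.e. 2**127
    print(torch.exp2(torch.floor(torch.log2(x))))  # tensor([inf])

Subnormal inputs would need the same clamp to 2**-126 that the test's expected-value computation applies, which is presumably what the "support subnormal values" TODO refers to.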