|
6 | 6 |
|
7 | 7 | # source for notable single-precision cases:
|
8 | 8 | # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
|
9 |
| -# |
10 |
| -# TODO(danielvegamyhre): |
11 |
| -# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23). |
12 |
| -# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value. |
13 |
| -# 2. add case for "nan" |
14 |
| -# need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf. |
15 |
| -# 3. adjust cases for subnormal values so we aren't clamping the expected results |
16 |
| -# into the normal range. |
17 |
| -# preliminary investigation shows it may not be possible to support all subnormals |
18 |
| -# with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x))) |
19 |
| -# approach. |
20 | 9 | @pytest.mark.parametrize(
|
21 |
| - "input", |
| 10 | + "test_case", |
22 | 11 | [
|
23 |
| - 1.0, |
24 |
| - float("inf"), |
25 |
| - # smallest positive subnormal number |
26 |
| - 2**-126 * 2**-23, |
27 |
| - # largest subnormal number |
28 |
| - 2**-126 * (1 - 2**-23), |
29 |
| - # smallest positive normal number |
30 |
| - 2**-126, |
31 |
| - # largest number less than one |
32 |
| - 1.0 - 2**-24, |
33 |
| - # smallest number larger than one |
34 |
| - 1.0 + 2**-23, |
| 12 | + # "test_case_name": [input, expected result] |
| 13 | + ("one", [1.0, 1.0]), |
| 14 | + ("inf", [float("inf"), float("inf")]), |
| 15 | + ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]), |
| 16 | + ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), |
| 17 | + ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]), |
| 18 | + ("smallest positive normal number", [2**-126, 2**-126]), |
| 19 | + ("largest number less than one", [1.0 - 2**-24, 0.5]), |
| 20 | + ("smallest number larger than one", [1.0 + 2**-23, 1.0]), |
35 | 21 | ],
|
36 | 22 | )
|
37 |
| -def test_round_scale_down_to_power_of_2_valid_inputs(input: float): |
38 |
| - input_tensor = torch.tensor(input, dtype=torch.float32) |
39 |
| - result = _round_scale_down_to_power_of_2(input_tensor) |
40 |
| - |
41 |
| - # get expected value for comparison |
42 |
| - # TODO(danielvegamyhre): support subnormal values |
43 |
| - expected_result = torch.exp2(torch.floor(torch.log2(input_tensor))) |
44 |
| - smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32) |
45 |
| - expected_result = torch.max(expected_result, smallest_normal_fp32_value) |
| 23 | +def test_round_scale_down_to_power_of_2_valid_inputs( |
| 24 | + test_case: dict, |
| 25 | +): |
| 26 | + test_case_name, (input, expected_result) = test_case |
| 27 | + input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result) |
46 | 28 |
|
| 29 | + result = _round_scale_down_to_power_of_2(input_tensor) |
47 | 30 | assert torch.equal(
|
48 |
| - result, expected_result |
49 |
| - ), f"input: {input_tensor}, expected {expected_result}, but got {result}" |
| 31 | + result, expected_tensor |
| 32 | + ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}" |
50 | 33 |
|
51 | 34 |
|
52 | 35 | @pytest.mark.parametrize(
|
|
0 commit comments