Skip to content

Commit 77d004e

Browse files
test nan
1 parent c6bcac8 commit 77d004e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/float8/test_float8_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# "test_case_name": [input, expected result]
1616
("one", [1.0, 1.0]),
1717
("inf", [float("inf"), float("inf")]),
18+
("nan", [float("nan"), float("nan")]),
1819
("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]),
1920
("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
2021
("largest normal number", [2**127 * (2 - 2**-23), float("inf")]),
@@ -32,8 +33,9 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
3233
torch.tensor(expected_result).cuda(),
3334
)
3435
result = _round_scale_down_to_power_of_2(input_tensor)
35-
assert torch.equal(
36-
result, expected_tensor
36+
assert (
37+
torch.equal(result, expected_tensor)
38+
or (result.isnan() and expected_tensor.isnan())
3739
), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
3840

3941

0 commit comments

Comments
 (0)