Skip to content

Commit c6bcac8

Browse files
run tests on gpu
1 parent 40166e1 commit c6bcac8

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

test/float8/test_float8_utils.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import unittest
2+
13
import pytest
24
import torch
35

@@ -6,6 +8,7 @@
68

79
# source for notable single-precision cases:
810
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
11+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
912
@pytest.mark.parametrize(
1013
"test_case",
1114
[
@@ -24,8 +27,10 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
2427
test_case: dict,
2528
):
2629
test_case_name, (input, expected_result) = test_case
27-
input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result)
28-
30+
input_tensor, expected_tensor = (
31+
torch.tensor(input).cuda(),
32+
torch.tensor(expected_result).cuda(),
33+
)
2934
result = _round_scale_down_to_power_of_2(input_tensor)
3035
assert torch.equal(
3136
result, expected_tensor

0 commit comments

Comments
 (0)