File tree 1 file changed +7
-2
lines changed
1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change
1
+ import unittest
2
+
1
3
import pytest
2
4
import torch
3
5
6
8
7
9
# source for notable single-precision cases:
8
10
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
11
+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
9
12
@pytest .mark .parametrize (
10
13
"test_case" ,
11
14
[
@@ -24,8 +27,10 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
24
27
test_case : dict ,
25
28
):
26
29
test_case_name , (input , expected_result ) = test_case
27
- input_tensor , expected_tensor = torch .tensor (input ), torch .tensor (expected_result )
28
-
30
+ input_tensor , expected_tensor = (
31
+ torch .tensor (input ).cuda (),
32
+ torch .tensor (expected_result ).cuda (),
33
+ )
29
34
result = _round_scale_down_to_power_of_2 (input_tensor )
30
35
assert torch .equal (
31
36
result , expected_tensor
You can’t perform that action at this time.
0 commit comments