Skip to content

Commit 6c6bcf2

Browse files
authored
Merge pull request #326 from crusaderky/cupy_to_device
TST: fix cupy `to_device` test on multiple devices
2 parents 559c82b + ebd3fd9 commit 6c6bcf2

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

tests/test_cupy.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,41 @@
55
from cupy.cuda import Stream
66

77

8-
def test_to_device_with_stream():
8+
@pytest.mark.parametrize(
9+
"make_stream",
10+
[
11+
lambda: Stream(),
12+
lambda: Stream(non_blocking=True),
13+
lambda: Stream(null=True),
14+
lambda: Stream(ptds=True),
15+
],
16+
)
17+
def test_to_device_with_stream(make_stream):
918
devices = xp.__array_namespace_info__().devices()
10-
streams = [
11-
Stream(),
12-
Stream(non_blocking=True),
13-
Stream(null=True),
14-
Stream(ptds=True),
15-
123, # dlpack stream
16-
]
1719

1820
a = xp.asarray([1, 2, 3])
1921
for dev in devices:
20-
for stream in streams:
21-
b = to_device(a, dev, stream=stream)
22-
assert device(b) == dev
22+
# Streams are device-specific and must be created within
23+
# the context of the device...
24+
with dev:
25+
stream = make_stream()
26+
# ... however, to_device() does not need to be inside the
27+
# device context.
28+
b = to_device(a, dev, stream=stream)
29+
assert device(b) == dev
30+
31+
32+
def test_to_device_with_dlpack_stream():
33+
devices = xp.__array_namespace_info__().devices()
34+
35+
a = xp.asarray([1, 2, 3])
36+
for dev in devices:
37+
# Streams are device-specific and must be created within
38+
# the context of the device...
39+
with dev:
40+
s1 = Stream()
41+
42+
# ... however, to_device() does not need to be inside the
43+
# device context.
44+
b = to_device(a, dev, stream=s1.ptr)
45+
assert device(b) == dev

0 commit comments

Comments
 (0)