Skip to content

Commit ebd3fd9

Browse files
committed
Use pointers
1 parent 0433b8e commit ebd3fd9

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

tests/test_cupy.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@
1212
lambda: Stream(non_blocking=True),
1313
lambda: Stream(null=True),
1414
lambda: Stream(ptds=True),
15-
pytest.param(
16-
lambda: 123,
17-
id="dlpack stream",
18-
marks=pytest.mark.skip(reason="segmentation fault reported (#326)")
19-
),
2015
],
2116
)
2217
def test_to_device_with_stream(make_stream):
@@ -32,3 +27,19 @@ def test_to_device_with_stream(make_stream):
3227
# device context.
3328
b = to_device(a, dev, stream=stream)
3429
assert device(b) == dev
30+
31+
32+
def test_to_device_with_dlpack_stream():
33+
devices = xp.__array_namespace_info__().devices()
34+
35+
a = xp.asarray([1, 2, 3])
36+
for dev in devices:
37+
# Streams are device-specific and must be created within
38+
# the context of the device...
39+
with dev:
40+
s1 = Stream()
41+
42+
# ... however, to_device() does not need to be inside the
43+
# device context.
44+
b = to_device(a, dev, stream=s1.ptr)
45+
assert device(b) == dev

0 commit comments

Comments
 (0)