Skip to content

Commit 112def5

Browse files
committed
add xp, device tests
1 parent faaa2b6 commit 112def5

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/array_api_extra/_funcs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,10 @@ def pad(
582582
xp = array_namespace(x)
583583

584584
padded = xp.full(
585-
tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype
585+
tuple(x + 2 * pad_width for x in x.shape),
586+
fill_value=value,
587+
dtype=x.dtype,
588+
device=_compat.device(x),
586589
)
587590
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
588591
return padded

tests/test_funcs.py

+8
Original file line numberDiff line numberDiff line change
@@ -408,3 +408,11 @@ def test_mode_not_implemented(self):
408408
a = xp.arange(3)
409409
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
410410
pad(a, 2, mode="edge")
411+
412+
def test_device(self):
413+
device = xp.Device("device1")
414+
a = xp.asarray(0.0, device=device)
415+
assert pad(a, 2).device == device
416+
417+
def test_xp(self):
418+
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))

0 commit comments

Comments
 (0)