Skip to content

Commit a88ad9a

Browse files
committed
Fixes to support NumPy 2.1.0
1 parent acdbb3c commit a88ad9a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/shmem4py/shmem.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,9 @@ def full(
997997
Valid hints are defined as enumerations in `MALLOC` and can be
998998
combined using the bitwise OR operator. Keyword argument only.
999999
"""
1000+
fill_value = np.array(fill_value)
10001001
if dtype is None:
1001-
dtype = np.array(fill_value).dtype
1002+
dtype = fill_value.dtype
10021003
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
10031004
np.copyto(a, fill_value, casting='unsafe')
10041005
lib.shmem_barrier_all()

test/test_rma.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def testGet(self):
6363
nxpe = (mype + 1) % npes
6464
for t in types:
6565
src = shmem.full(1, mype, dtype=t)
66-
dst = np.full(1, -1, dtype=t)
66+
dst = np.full(1, np.array(-1), dtype=t)
6767
shmem.barrier_all()
6868
shmem.get(dst, src, nxpe)
6969
self.assertEqual(dst[0], nxpe)
@@ -148,7 +148,7 @@ def testGetNBI(self):
148148
nxpe = (mype + 1) % npes
149149
for t in types:
150150
src = shmem.full(1, mype, dtype=t)
151-
dst = np.full(1, -1, dtype=t)
151+
dst = np.full(1, np.array(-1), dtype=t)
152152
shmem.barrier_all()
153153
shmem.get_nbi(dst, src, nxpe)
154154
shmem.fence()

0 commit comments

Comments
 (0)