Skip to content

Commit

Permalink
Accept start point
Browse files Browse the repository at this point in the history
  • Loading branch information
mrogowski committed Apr 14, 2023
1 parent 903a743 commit e5b966d
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions src/shmem4py/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,9 +872,27 @@ def _ceildiv(p: int, q: int) -> int:
# ---


def _parse_stride(st):
if isinstance(st, int):
stride = st
start = [0]
elif isinstance(st, tuple):
stride, start = st[1], list(st[0])

return stride, start

def _parse_rma(target, source, size=None, tst=1, sst=1):
tdata, tlen, ttype = _getbuffer(target, readonly=False)
sdata, slen, stype = _getbuffer(source, readonly=True)
if isinstance(tst, tuple): assert target.ndim == len(tst[0])
if isinstance(sst, tuple): assert source.ndim == len(sst[0])
tst, tstart = _parse_stride(tst)
sst, sstart = _parse_stride(sst)

if tstart != [0]:
tdata, tlen, ttype = _getbuffer(target[*tstart[:-1],tstart[-1]:], readonly=False)
sdata, slen, stype = _getbuffer(source[*sstart[:-1],sstart[-1]:], readonly=True)
else:
tdata, tlen, ttype = _getbuffer(target, readonly=False)
sdata, slen, stype = _getbuffer(source, readonly=True)

assert ttype == stype
ctype = ttype
Expand All @@ -884,8 +902,7 @@ def _parse_rma(target, source, size=None, tst=1, sst=1):
if size is None:
size = min(tsize, ssize)
else:
assert size <= tsize
assert size <= ssize
assert size >= 0

return (ctype, tdata, sdata, size)

Expand All @@ -901,6 +918,8 @@ def _shmem_rma(ctx, name, target, source, size, pe):

def _shmem_irma(ctx, name, target, source, tst, sst, size, pe):
ctype, target, source, size = _parse_rma(target, source, size, tst, sst)
tst, _ = _parse_stride(tst)
sst, _ = _parse_stride(sst)
return _shmem(ctx, ctype, f'i{name}')(target, source, tst, sst, size, pe)


Expand Down

0 comments on commit e5b966d

Please sign in to comment.