From cc0a602796af7579d96deee378a3e336402473c6 Mon Sep 17 00:00:00 2001 From: Marcin Rogowski Date: Fri, 14 Apr 2023 14:54:08 +0300 Subject: [PATCH] Accept start point --- src/shmem4py/shmem.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/shmem4py/shmem.py b/src/shmem4py/shmem.py index c8ba12c..cc5584a 100644 --- a/src/shmem4py/shmem.py +++ b/src/shmem4py/shmem.py @@ -872,9 +872,30 @@ def _ceildiv(p: int, q: int) -> int: # --- +def _parse_stride(st): + if isinstance(st, int): + stride = st + start = [0] + elif isinstance(st, tuple): + stride, start = st[1], list(st[0]) + + return stride, start + def _parse_rma(target, source, size=None, tst=1, sst=1): - tdata, tlen, ttype = _getbuffer(target, readonly=False) - sdata, slen, stype = _getbuffer(source, readonly=True) + if isinstance(tst, tuple): assert target.ndim == len(tst[0]) + if isinstance(sst, tuple): assert source.ndim == len(sst[0]) + tst, tstart = _parse_stride(tst) + sst, sstart = _parse_stride(sst) + + if tstart != [0]: + tdata, tlen, ttype = _getbuffer(target[*tstart[:-1],tstart[-1]:], readonly=False) + else: + tdata, tlen, ttype = _getbuffer(target, readonly=False) + + if sstart != [0]: + sdata, slen, stype = _getbuffer(source[*sstart[:-1],sstart[-1]:], readonly=True) + else: + sdata, slen, stype = _getbuffer(source, readonly=True) assert ttype == stype ctype = ttype @@ -884,8 +905,7 @@ def _parse_rma(target, source, size=None, tst=1, sst=1): if size is None: size = min(tsize, ssize) else: - assert size <= tsize - assert size <= ssize + assert size >= 0 return (ctype, tdata, sdata, size) @@ -901,6 +921,8 @@ def _shmem_rma(ctx, name, target, source, size, pe): def _shmem_irma(ctx, name, target, source, tst, sst, size, pe): ctype, target, source, size = _parse_rma(target, source, size, tst, sst) + tst, _ = _parse_stride(tst) + sst, _ = _parse_stride(sst) return _shmem(ctx, ctype, f'i{name}')(target, source, tst, sst, size, pe)