@@ -872,9 +872,27 @@ def _ceildiv(p: int, q: int) -> int:
872
872
# ---
873
873
874
874
875
+ def _parse_stride (st ):
876
+ if isinstance (st , int ):
877
+ stride = st
878
+ start = [0 ]
879
+ elif isinstance (st , tuple ):
880
+ stride , start = st [1 ], list (st [0 ])
881
+
882
+ return stride , start
883
+
875
884
def _parse_rma (target , source , size = None , tst = 1 , sst = 1 ):
876
- tdata , tlen , ttype = _getbuffer (target , readonly = False )
877
- sdata , slen , stype = _getbuffer (source , readonly = True )
885
+ if isinstance (tst , tuple ): assert target .ndim == len (tst [0 ])
886
+ if isinstance (sst , tuple ): assert source .ndim == len (sst [0 ])
887
+ tst , tstart = _parse_stride (tst )
888
+ sst , sstart = _parse_stride (sst )
889
+
890
+ if tstart != [0 ]:
891
+ tdata , tlen , ttype = _getbuffer (target [* tstart [:- 1 ],tstart [- 1 ]:], readonly = False )
892
+ sdata , slen , stype = _getbuffer (source [* sstart [:- 1 ],sstart [- 1 ]:], readonly = True )
893
+ else :
894
+ tdata , tlen , ttype = _getbuffer (target , readonly = False )
895
+ sdata , slen , stype = _getbuffer (source , readonly = True )
878
896
879
897
assert ttype == stype
880
898
ctype = ttype
@@ -884,8 +902,7 @@ def _parse_rma(target, source, size=None, tst=1, sst=1):
884
902
if size is None :
885
903
size = min (tsize , ssize )
886
904
else :
887
- assert size <= tsize
888
- assert size <= ssize
905
+ assert size >= 0
889
906
890
907
return (ctype , tdata , sdata , size )
891
908
@@ -901,6 +918,8 @@ def _shmem_rma(ctx, name, target, source, size, pe):
901
918
902
919
def _shmem_irma (ctx , name , target , source , tst , sst , size , pe ):
903
920
ctype , target , source , size = _parse_rma (target , source , size , tst , sst )
921
+ tst , _ = _parse_stride (tst )
922
+ sst , _ = _parse_stride (sst )
904
923
return _shmem (ctx , ctype , f'i{ name } ' )(target , source , tst , sst , size , pe )
905
924
906
925
0 commit comments