@@ -268,6 +268,21 @@ void unpack(void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
268268 });
269269}
270270
271+ // / copy contiguous block of data into a possibly strided array
272+ void unpack1 (void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
273+ const int64_t *strides, uint64_t ndim, void *out) {
274+ if (!in || !sizes || !strides || !out) {
275+ return ;
276+ }
277+ dispatch (dtype, out, [sizes, strides, ndim, in](auto *out_) {
278+ auto in_ = static_cast <decltype (out_)>(in);
279+ SHARPY::forall (0 , out_, sizes, strides, ndim, [&in_](auto *out) {
280+ *out = *in_;
281+ ++in_;
282+ });
283+ });
284+ }
285+
271286template <typename T>
272287void copy_ (uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
273288 const int64_t *strides, const uint64_t *chunks, uint64_t nd,
@@ -489,21 +504,41 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
489504 }
490505 }
491506
507+ int64_t oStride = std::accumulate (oDataStridesPtr, oDataStridesPtr + oNDims,
508+ 1 , std::multiplies<int64_t >());
509+ void *rBuff = oDataPtr;
510+ if (oStride != 1 ) {
511+ rBuff = new char [sizeof_dtype (sharpytype) * myOSz];
512+ }
513+
492514 SHARPY::Buffer sendbuff (totSSz * sizeof_dtype (sharpytype), 2 );
493515 bufferizeN (iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N,
494516 lsOffs.data (), lsEnds.data (), sendbuff.data ());
495517 auto hdl = tc->alltoall (sendbuff.data (), sszs.data (), soffs.data (),
496- sharpytype, oDataPtr , rszs.data (), roffs.data ());
518+ sharpytype, rBuff , rszs.data (), roffs.data ());
497519
498520 if (no_async) {
499521 tc->wait (hdl);
522+ if (oStride != 1 ) {
523+ unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
524+ oDataPtr);
525+ delete[] (char *)rBuff;
526+ }
500527 return nullptr ;
501528 }
502529
503- auto wait = [tc = tc, hdl = hdl, sendbuff = std::move (sendbuff),
504- sszs = std::move (sszs), soffs = std::move (soffs),
505- rszs = std::move (rszs),
506- roffs = std::move (roffs)]() { tc->wait (hdl); };
530+ auto wait = [tc, hdl, oStride, rBuff, sharpytype, oDataShapePtr,
531+ oDataStridesPtr, oNDims, oDataPtr,
532+ sendbuff = std::move (sendbuff), sszs = std::move (sszs),
533+ soffs = std::move (soffs), rszs = std::move (rszs),
534+ roffs = std::move (roffs)]() {
535+ tc->wait (hdl);
536+ if (oStride != 1 ) {
537+ unpack1 (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
538+ oDataPtr);
539+ delete[] (char *)rBuff;
540+ }
541+ };
507542 assert (sendbuff.empty () && sszs.empty () && soffs.empty () && rszs.empty () &&
508543 roffs.empty ());
509544 return mkWaitHandle (std::move (wait));
0 commit comments