@@ -575,6 +575,8 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
575
575
oData.data (), oData.sizes (), oData.strides ());
576
576
}
577
577
578
+ namespace {
579
+
578
580
class id {
579
581
public:
580
582
/// Creates an index with `dims` components; `_values` is a
/// std::vector<int64_t> sized to `dims`, so every component starts at 0.
id (size_t dims) : _values(dims) {}
@@ -609,45 +611,38 @@ class id {
609
611
return id (std::move (new_values));
610
612
}
611
613
614
+ void next (const int64_t *shape) {
615
+ size_t i = _values.size ();
616
+ while (i--) {
617
+ ++_values[i];
618
+ if (_values[i] < shape[i]) {
619
+ return ;
620
+ }
621
+ _values[i] = 0 ;
622
+ }
623
+ }
624
+
612
625
size_t size () { return _values.size (); }
613
626
614
627
private:
615
628
std::vector<int64_t > _values;
616
629
};
617
630
618
- id &next_idx (id &idx, const int64_t *shape) {
619
- size_t i = idx.size ();
620
- while (i--) {
621
- ++idx[i];
622
- if (idx[i] < shape[i]) {
623
- return idx;
624
- }
625
- idx[i] = 0 ;
626
- }
627
- return idx;
628
- }
629
-
630
631
template <typename T> class ndarray {
631
632
public:
632
633
ndarray (int64_t nDims, int64_t *gShape , int64_t *gOffsets , void *lData,
633
634
int64_t *lShape, int64_t *lStrides)
634
635
: _nDims(nDims), _gShape(gShape ), _gOffsets(gOffsets ), _lData((T *)lData),
635
636
_lShape (lShape), _lStrides(lStrides) {}
636
- // ndarray(std::vector<T> input, std::vector<int64_t> dims,
637
- // std::vector<int64_t> strides);
638
-
639
- // id ids();
640
- // id local_ids();
641
637
642
638
/// Returns the index at which local iteration starts, constructed from
/// `_nDims` and `_gOffsets`.
/// NOTE(review): the two-argument `id` constructor is not visible here;
/// presumably it copies `_nDims` values from `_gOffsets` — confirm.
id firstLocalIndex () const { return id (_nDims, _gOffsets); }
643
639
644
640
void localIndices (const std::function<void (const id &)> &callback) const {
645
641
size_t size = lSize ();
646
642
id idx = firstLocalIndex ();
647
643
while (size--) {
648
- std::cout << " idx: " << idx[0 ] << " ," << idx[1 ] << std::endl;
649
644
callback (idx);
650
- next_idx ( idx, _gShape);
645
+ idx. next ( _gShape);
651
646
}
652
647
}
653
648
@@ -658,7 +653,6 @@ template <typename T> class ndarray {
658
653
offset = (offset + localIdx[i]) * _lShape[i + 1 ];
659
654
}
660
655
offset += localIdx[_nDims - 1 ];
661
- std::cout << " offset: " << offset << std::endl;
662
656
return offset;
663
657
}
664
658
@@ -714,47 +708,7 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
714
708
return 0 ;
715
709
}
716
710
717
- // template <typename T>
718
- // void permute(const ndarray<T> &input, const ndarray<T> &output, uint64_t
719
- // nRanks,
720
- // const std::vector<Parts> &parts,
721
- // const std::vector<int64_t> &axes, ) {
722
- // std::vector<std::vector<T>> sendBuffer(nRanks); // alltoall
723
-
724
- // input.permutedLocalIds(
725
- // [&](const id &idx) {
726
- // auto rank = getOutputRank(parts, idx[0]);
727
- // sendBuffer[rank].push_back(input[idx]);
728
- // },
729
- // axes);
730
-
731
- // std::vector<int> receiveSizes(nRanks);
732
- // std::vector<int> receiveOffsets(nRanks);
733
-
734
- // output.permutedLocalIds(
735
- // [&](const id &idx) {
736
- // auto rank = getInputRank(parts, idx[0]);
737
- // ++receiveSizes[rank];
738
- // },
739
- // axes);
740
- // for (size_t i = 1; i < nRanks; i++) {
741
- // receiveOffsets[i] = receiveOffsets[i - 1] + receiveSizes[i - 1];
742
- // }
743
-
744
- // return sendBuffer;
745
- // }
746
-
747
- // template <typename T>
748
- // void detranspose(std::vector<std::vector<T>> sendBuffer, ndarray<T> output,
749
- // std::vector<int64_t> axes, uint64_t nRank) {
750
- // std::vector<size_t> sendBufferIndex(sendBuffer.size());
751
- // for (auto idx : output) {
752
- // id in_idx = idx.permute(axes);
753
- // auto i = sendBufferIndex[in_idx[0]];
754
- // output[idx] = sendBuffer[in_idx[0]][i];
755
- // sendBufferIndex[in_idx[0]] = i + 1;
756
- // }
757
- // }
711
+ } // namespace
758
712
759
713
/// @brief permute array
760
714
/// We assume the array is partitioned along the first dimension (only) and
@@ -826,7 +780,6 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
826
780
}
827
781
828
782
// First we allgather the current and target partitioning
829
-
830
783
std::vector<Parts> parts (nRanks);
831
784
parts[rank].iStart = iOffsPtr[0 ];
832
785
parts[rank].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
@@ -840,7 +793,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
840
793
tc->gather (parts.data (), counts.data (), dspl.data (), SHARPY::INT64,
841
794
SHARPY::REPLICATED);
842
795
843
- // transpose
796
+ // Transpose
844
797
ndarray<T> input (iNDims, iGShapePtr, iOffsPtr, iDataPtr, iDataShapePtr,
845
798
iDataStridesPtr);
846
799
ndarray<T> output (oNDims, oGShapePtr, oOffsPtr, oDataPtr, oDataShapePtr,
@@ -882,28 +835,11 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
882
835
}
883
836
}
884
837
885
- std::cout << " sendSizes: " << sendSizes[0 ] << std::endl;
886
- std::cout << " sendOffsets: " << sendOffsets[0 ] << std::endl;
887
- std::cout << " sendBuffer: " ;
888
- for (int i = 0 ; i < sendBuffer.size (); ++i) {
889
- std::cout << sendBuffer[i] << " ," ;
890
- }
891
- std::cout << std::endl;
892
-
893
- std::cout << " receiveSizes: " << receiveSizes[0 ] << std::endl;
894
- std::cout << " receiveOffsets: " << receiveOffsets[0 ] << std::endl;
895
-
896
838
auto hdl = tc->alltoall (sendBuffer.data (), sendSizes.data (),
897
839
sendOffsets.data (), sharpytype, receiveBuffer.data (),
898
840
receiveSizes.data (), receiveOffsets.data ());
899
841
tc->wait (hdl);
900
842
901
- std::cout << " receiveBuffer: " ;
902
- for (int i = 0 ; i < receiveBuffer.size (); ++i) {
903
- std::cout << receiveBuffer[i] << " ," ;
904
- }
905
- std::cout << std::endl;
906
-
907
843
{
908
844
std::vector<std::vector<T>> receiveRankBuffer (nRanks);
909
845
for (int64_t rank = 0 ; rank < nRanks; ++rank) {
0 commit comments