Skip to content

Commit d06819d

Browse files
committed
clean code
1 parent 9671713 commit d06819d

File tree

1 file changed

+16
-80
lines changed

1 file changed

+16
-80
lines changed

Diff for: src/idtr.cpp

+16-80
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
575575
oData.data(), oData.sizes(), oData.strides());
576576
}
577577

578+
namespace {
579+
578580
class id {
579581
public:
580582
id(size_t dims) : _values(dims) {}
@@ -609,45 +611,38 @@ class id {
609611
return id(std::move(new_values));
610612
}
611613

614+
void next(const int64_t *shape) {
615+
size_t i = _values.size();
616+
while (i--) {
617+
++_values[i];
618+
if (_values[i] < shape[i]) {
619+
return;
620+
}
621+
_values[i] = 0;
622+
}
623+
}
624+
612625
size_t size() { return _values.size(); }
613626

614627
private:
615628
std::vector<int64_t> _values;
616629
};
617630

618-
id &next_idx(id &idx, const int64_t *shape) {
619-
size_t i = idx.size();
620-
while (i--) {
621-
++idx[i];
622-
if (idx[i] < shape[i]) {
623-
return idx;
624-
}
625-
idx[i] = 0;
626-
}
627-
return idx;
628-
}
629-
630631
template <typename T> class ndarray {
631632
public:
632633
ndarray(int64_t nDims, int64_t *gShape, int64_t *gOffsets, void *lData,
633634
int64_t *lShape, int64_t *lStrides)
634635
: _nDims(nDims), _gShape(gShape), _gOffsets(gOffsets), _lData((T *)lData),
635636
_lShape(lShape), _lStrides(lStrides) {}
636-
// ndarray(std::vector<T> input, std::vector<int64_t> dims,
637-
// std::vector<int64_t> strides);
638-
639-
// id ids();
640-
// id local_ids();
641637

642638
id firstLocalIndex() const { return id(_nDims, _gOffsets); }
643639

644640
void localIndices(const std::function<void(const id &)> &callback) const {
645641
size_t size = lSize();
646642
id idx = firstLocalIndex();
647643
while (size--) {
648-
std::cout << "idx: " << idx[0] << "," << idx[1] << std::endl;
649644
callback(idx);
650-
next_idx(idx, _gShape);
645+
idx.next(_gShape);
651646
}
652647
}
653648

@@ -658,7 +653,6 @@ template <typename T> class ndarray {
658653
offset = (offset + localIdx[i]) * _lShape[i + 1];
659654
}
660655
offset += localIdx[_nDims - 1];
661-
std::cout << "offset: " << offset << std::endl;
662656
return offset;
663657
}
664658

@@ -714,47 +708,7 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
714708
return 0;
715709
}
716710

717-
// template <typename T>
718-
// void permute(const ndarray<T> &input, const ndarray<T> &output, uint64_t
719-
// nRanks,
720-
// const std::vector<Parts> &parts,
721-
// const std::vector<int64_t> &axes, ) {
722-
// std::vector<std::vector<T>> sendBuffer(nRanks); // alltoall
723-
724-
// input.permutedLocalIds(
725-
// [&](const id &idx) {
726-
// auto rank = getOutputRank(parts, idx[0]);
727-
// sendBuffer[rank].push_back(input[idx]);
728-
// },
729-
// axes);
730-
731-
// std::vector<int> receiveSizes(nRanks);
732-
// std::vector<int> receiveOffsets(nRanks);
733-
734-
// output.permutedLocalIds(
735-
// [&](const id &idx) {
736-
// auto rank = getInputRank(parts, idx[0]);
737-
// ++receiveSizes[rank];
738-
// },
739-
// axes);
740-
// for (size_t i = 1; i < nRanks; i++) {
741-
// receiveOffsets[i] = receiveOffsets[i - 1] + receiveSizes[i - 1];
742-
// }
743-
744-
// return sendBuffer;
745-
// }
746-
747-
// template <typename T>
748-
// void detranspose(std::vector<std::vector<T>> sendBuffer, ndarray<T> output,
749-
// std::vector<int64_t> axes, uint64_t nRank) {
750-
// std::vector<size_t> sendBufferIndex(sendBuffer.size());
751-
// for (auto idx : output) {
752-
// id in_idx = idx.permute(axes);
753-
// auto i = sendBufferIndex[in_idx[0]];
754-
// output[idx] = sendBuffer[in_idx[0]][i];
755-
// sendBufferIndex[in_idx[0]] = i + 1;
756-
// }
757-
// }
711+
} // namespace
758712

759713
/// @brief permute array
760714
/// We assume array is partitioned along the first dimension (only) and
@@ -826,7 +780,6 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
826780
}
827781

828782
// First we allgather the current and target partitioning
829-
830783
std::vector<Parts> parts(nRanks);
831784
parts[rank].iStart = iOffsPtr[0];
832785
parts[rank].iEnd = iOffsPtr[0] + iDataShapePtr[0];
@@ -840,7 +793,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
840793
tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64,
841794
SHARPY::REPLICATED);
842795

843-
// transpose
796+
// Transpose
844797
ndarray<T> input(iNDims, iGShapePtr, iOffsPtr, iDataPtr, iDataShapePtr,
845798
iDataStridesPtr);
846799
ndarray<T> output(oNDims, oGShapePtr, oOffsPtr, oDataPtr, oDataShapePtr,
@@ -882,28 +835,11 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
882835
}
883836
}
884837

885-
std::cout << "sendSizes: " << sendSizes[0] << std::endl;
886-
std::cout << "sendOffsets: " << sendOffsets[0] << std::endl;
887-
std::cout << "sendBuffer: ";
888-
for (int i = 0; i < sendBuffer.size(); ++i) {
889-
std::cout << sendBuffer[i] << ",";
890-
}
891-
std::cout << std::endl;
892-
893-
std::cout << "receiveSizes: " << receiveSizes[0] << std::endl;
894-
std::cout << "receiveOffsets: " << receiveOffsets[0] << std::endl;
895-
896838
auto hdl = tc->alltoall(sendBuffer.data(), sendSizes.data(),
897839
sendOffsets.data(), sharpytype, receiveBuffer.data(),
898840
receiveSizes.data(), receiveOffsets.data());
899841
tc->wait(hdl);
900842

901-
std::cout << "receiveBuffer: ";
902-
for (int i = 0; i < receiveBuffer.size(); ++i) {
903-
std::cout << receiveBuffer[i] << ",";
904-
}
905-
std::cout << std::endl;
906-
907843
{
908844
std::vector<std::vector<T>> receiveRankBuffer(nRanks);
909845
for (int64_t rank = 0; rank < nRanks; ++rank) {

0 commit comments

Comments
 (0)