Skip to content

Commit a3908ab

Browse files
committed
fix bug and warning
1 parent 6767196 commit a3908ab

9 files changed

+21
-18
lines changed

Diff for: src/EWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ struct DeferredEWBinOp : public Deferred {
120120
auto av = dm.getDependent(builder, Registry::get(_a));
121121
auto bv = dm.getDependent(builder, Registry::get(_b));
122122

123-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
123+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
124124
auto outElemType =
125125
::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
126126
auto outTyp = aTyp.cloneWith(shape(), outElemType);

Diff for: src/EWUnyOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct DeferredEWUnyOp : public Deferred {
105105
jit::DepManager &dm) override {
106106
auto av = dm.getDependent(builder, Registry::get(_a));
107107

108-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
108+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
109109
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
110110

111111
auto ndOpId = sharpy(_op);

Diff for: src/IEWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct DeferredIEWBinOp : public Deferred {
7171
auto av = dm.getDependent(builder, Registry::get(_a));
7272
auto bv = dm.getDependent(builder, Registry::get(_b));
7373

74-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
74+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
7575
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
7676

7777
auto binop = builder.create<::imex::ndarray::EWBinOp>(

Diff for: src/ManipOp.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred {
4141
? ::mlir::IntegerAttr()
4242
: ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1);
4343

44-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
44+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
4545
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());
4646

4747
auto op =
@@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred {
106106
// construct NDArrayType with same shape and given dtype
107107
::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
108108
auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType);
109-
auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
109+
auto arType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
110110
if (!arType) {
111111
throw std::invalid_argument(
112112
"Encountered unexpected ndarray type in astype.");
@@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred {
157157
jit::DepManager &dm) override {
158158
auto av = dm.getDependent(builder, Registry::get(_a));
159159

160-
auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
160+
auto srcType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
161161
if (!srcType) {
162162
throw std::invalid_argument(
163163
"Encountered unexpected ndarray type in to_device.");

Diff for: src/NDArray.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ void NDArray::NDADeleter::operator()(NDArray *a) const {
113113
std::cerr << "sharpy fini: detected possible memory leak\n";
114114
} else {
115115
auto av = dm.addDependent(builder, a);
116-
builder.create<::imex::ndarray::DeleteOp>(loc, av);
116+
auto deleteOp = builder.create<::imex::ndarray::DeleteOp>(loc, av);
117+
deleteOp->setAttr("bufferization.manual_deallocation",
118+
builder.getUnitAttr());
117119
dm.drop(a->guid());
118120
}
119121
return false;

Diff for: src/ReduceOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct DeferredReduceOp : public Deferred {
6161
// FIXME reduction over individual dimensions is not supported
6262
auto av = dm.getDependent(builder, Registry::get(_a));
6363
// return type 0d with same dtype as input
64-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
64+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
6565
auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());
6666
// reduction op
6767
auto mop = sharpy2mlir(_op);

Diff for: src/SetGetItem.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ struct DeferredGetItem : public Deferred {
277277
const auto &offs = _slc.offsets();
278278
const auto &sizes = shape();
279279
const auto &strides = _slc.strides();
280-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
280+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
281281
auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());
282282

283283
// now we can create the NDArray op using the above Values

Diff for: src/idtr.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,6 @@ template <typename T> class ndarray {
641641
size_t size = lSize();
642642
id idx = firstLocalIndex();
643643
while (size--) {
644-
std::cout << "idx: " << idx[0] << ", " << idx[1] << std::endl;
645644
callback(idx);
646645
idx.next(_gShape);
647646
}
@@ -713,12 +712,13 @@ template <typename T> class WaitPermute {
713712
public:
714713
WaitPermute(SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
715714
SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
716-
std::vector<int64_t> &&axes, ndarray<T> &&output,
717-
std::vector<T> &&receiveBuffer, std::vector<int> &&receiveOffsets,
715+
std::vector<int64_t> &&axes, std::vector<int64_t> oGShape,
716+
ndarray<T> &&output, std::vector<T> &&receiveBuffer,
717+
std::vector<int> &&receiveOffsets,
718718
std::vector<int> &&receiveSizes)
719719
: tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
720-
axes(std::move(axes)), output(std::move(output)),
721-
receiveBuffer(std::move(receiveBuffer)),
720+
axes(std::move(axes)), oGShape(std::move(oGShape)),
721+
output(std::move(output)), receiveBuffer(std::move(receiveBuffer)),
722722
receiveOffsets(std::move(receiveOffsets)),
723723
receiveSizes(std::move(receiveSizes)) {}
724724

@@ -735,8 +735,6 @@ template <typename T> class WaitPermute {
735735
std::vector<size_t> receiveRankBufferCount(nRanks, 0);
736736
output.localIndices([&](const id &outputIndex) {
737737
id inputIndex = outputIndex.permute(axes);
738-
std::cout << "inputIndex: " << inputIndex[0] << ", " << inputIndex[1]
739-
<< std::endl;
740738
auto rank = getInputRank(parts, inputIndex[0]);
741739
auto &count = receiveRankBufferCount[rank];
742740
output[outputIndex] = receiveRankBuffer[rank][count++];
@@ -749,6 +747,7 @@ template <typename T> class WaitPermute {
749747
SHARPY::rank_type nRanks;
750748
std::vector<Parts> parts;
751749
std::vector<int64_t> axes;
750+
std::vector<int64_t> oGShape;
752751
ndarray<T> output;
753752
std::vector<T> receiveBuffer;
754753
std::vector<int> receiveOffsets;
@@ -893,8 +892,9 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
893892
receiveSizes.data(), receiveOffsets.data());
894893

895894
auto wait = WaitPermute(tc, hdl, nRanks, std::move(parts), std::move(axes),
896-
std::move(output), std::move(receiveBuffer),
897-
std::move(receiveOffsets), std::move(receiveSizes));
895+
std::move(oGShape), std::move(output),
896+
std::move(receiveBuffer), std::move(receiveOffsets),
897+
std::move(receiveSizes));
898898

899899
assert(parts.empty() && axes.empty() && receiveBuffer.empty() &&
900900
receiveOffsets.empty() && receiveSizes.empty());

Diff for: src/jit/mlir.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ static const std::string cpu_pipeline =
692692
"canonicalize,"
693693
"imex-remove-temporaries,"
694694
"buffer-deallocation-pipeline,"
695+
"convert-bufferization-to-memref,"
695696
"func.func(convert-linalg-to-parallel-loops),"
696697
"func.func(scf-parallel-loop-fusion),"
697698
"drop-regions,"

0 commit comments

Comments
 (0)