Skip to content

Commit b56783c

Browse files
committed
Allow zero sized dimensiosn in tosa.reshape
1 parent 190b9ba commit b56783c

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+20-19
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ LogicalResult tosa::SliceOp::verify() {
890890
if (!inputType || !outputType)
891891
return success();
892892

893-
if (inputType.getRank() != outputType.getRank()) {
893+
if (inputType.getRank() != outputType.getRank()) {
894894
return emitOpError() << "rank of input (" << inputType.getRank()
895895
<< ") and output (" << outputType.getRank()
896896
<< ") must match";
@@ -1087,34 +1087,35 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
10871087
return emitOpError() << "cannot reshape " << inputElementsNum
10881088
<< " elements into " << outputElementsNum;
10891089
}
1090+
1091+
if ((int64_t)getNewShape().size() != outputType.getRank()) {
1092+
return emitOpError()
1093+
<< "rank of newShape (" << getNewShape().size()
1094+
<< ") and output (" << outputType.getRank() << ") must match";
1095+
}
1096+
1097+
for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
1098+
if (getNewShape()[dim] != -1 &&
1099+
getNewShape()[dim] != outputType.getShape()[dim]) {
1100+
return emitOpError()
1101+
<< "newShape attribute (" << getNewShape()[dim]
1102+
<< ") does not match output type ("
1103+
<< outputType.getShape()[dim] << ") in dimension " << dim;
1104+
}
1105+
}
10901106
}
10911107

1108+
// AMD: Switched checks with > to >= to allow zero dimensions
10921109
int64_t newShapeElementsNum = std::accumulate(
10931110
getNewShape().begin(), getNewShape().end(), 1LL,
1094-
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1111+
[](int64_t acc, int64_t dim) { return (dim >= 0) ? acc * dim : acc; });
10951112
bool isStaticNewShape =
1096-
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1113+
llvm::all_of(getNewShape(), [](int64_t s) { return s >= 0; });
10971114
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
10981115
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
10991116
return emitOpError() << "cannot reshape " << inputElementsNum
11001117
<< " elements into " << newShapeElementsNum;
11011118
}
1102-
1103-
if ((int64_t)getNewShape().size() != outputType.getRank()) {
1104-
return emitOpError() << "rank of newShape (" << getNewShape().size()
1105-
<< ") and output (" << outputType.getRank()
1106-
<< ") must match";
1107-
}
1108-
1109-
for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
1110-
if (getNewShape()[dim] != -1 &&
1111-
getNewShape()[dim] != outputType.getShape()[dim]) {
1112-
return emitOpError()
1113-
<< "newShape attribute (" << getNewShape()[dim]
1114-
<< ") does not match output type (" << outputType.getShape()[dim]
1115-
<< ") in dimension " << dim;
1116-
}
1117-
}
11181119
}
11191120

11201121
int missingDims = llvm::count(getNewShape(), -1);

0 commit comments

Comments
 (0)