|
30 | 30 | #include "llvm/ADT/DenseMap.h"
|
31 | 31 | #include "llvm/ADT/TypeSwitch.h"
|
32 | 32 |
|
| 33 | +#include <numeric> |
| 34 | + |
33 | 35 | using namespace mlir;
|
34 | 36 | using namespace mlir::tosa;
|
35 | 37 |
|
@@ -888,7 +890,7 @@ LogicalResult tosa::SliceOp::verify() {
|
888 | 890 | if (!inputType || !outputType)
|
889 | 891 | return success();
|
890 | 892 |
|
891 |
| - if (inputType.getRank() != outputType.getRank()) { |
| 893 | + if (inputType.getRank() != outputType.getRank()) { |
892 | 894 | return emitOpError() << "rank of input (" << inputType.getRank()
|
893 | 895 | << ") and output (" << outputType.getRank()
|
894 | 896 | << ") must match";
|
@@ -1077,28 +1079,42 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
|
1077 | 1079 | << newShapeDim;
|
1078 | 1080 | }
|
1079 | 1081 |
|
1080 |
| - if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
| 1082 | + if (inputType.hasStaticShape()) { |
1081 | 1083 | int64_t inputElementsNum = inputType.getNumElements();
|
1082 |
| - int64_t outputElementsNum = outputType.getNumElements(); |
1083 |
| - if (inputElementsNum != outputElementsNum) { |
1084 |
| - return emitOpError() << "cannot reshape " << inputElementsNum |
1085 |
| - << " elements into " << outputElementsNum; |
1086 |
| - } |
| 1084 | + if (outputType.hasStaticShape()) { |
| 1085 | + int64_t outputElementsNum = outputType.getNumElements(); |
| 1086 | + if (inputElementsNum != outputElementsNum) { |
| 1087 | + return emitOpError() << "cannot reshape " << inputElementsNum |
| 1088 | + << " elements into " << outputElementsNum; |
| 1089 | + } |
| 1090 | + |
| 1091 | + if ((int64_t)getNewShape().size() != outputType.getRank()) { |
| 1092 | + return emitOpError() |
| 1093 | + << "rank of newShape (" << getNewShape().size() |
| 1094 | + << ") and output (" << outputType.getRank() << ") must match"; |
| 1095 | + } |
1087 | 1096 |
|
1088 |
| - if ((int64_t)getNewShape().size() != outputType.getRank()) { |
1089 |
| - return emitOpError() << "rank of newShape (" << getNewShape().size() |
1090 |
| - << ") and output (" << outputType.getRank() |
1091 |
| - << ") must match"; |
| 1097 | + for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { |
| 1098 | + if (getNewShape()[dim] != -1 && |
| 1099 | + getNewShape()[dim] != outputType.getShape()[dim]) { |
| 1100 | + return emitOpError() |
| 1101 | + << "newShape attribute (" << getNewShape()[dim] |
| 1102 | + << ") does not match output type (" |
| 1103 | + << outputType.getShape()[dim] << ") in dimension " << dim; |
| 1104 | + } |
| 1105 | + } |
1092 | 1106 | }
|
1093 | 1107 |
|
1094 |
| - for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { |
1095 |
| - if (getNewShape()[dim] != -1 && |
1096 |
| - getNewShape()[dim] != outputType.getShape()[dim]) { |
1097 |
| - return emitOpError() |
1098 |
| - << "newShape attribute (" << getNewShape()[dim] |
1099 |
| - << ") does not match output type (" << outputType.getShape()[dim] |
1100 |
| - << ") in dimension " << dim; |
1101 |
| - } |
| 1108 | + // AMD: Switched checks with > to >= to allow zero dimensions |
| 1109 | + int64_t newShapeElementsNum = std::accumulate( |
| 1110 | + getNewShape().begin(), getNewShape().end(), 1LL, |
| 1111 | + [](int64_t acc, int64_t dim) { return (dim >= 0) ? acc * dim : acc; }); |
| 1112 | + bool isStaticNewShape = |
| 1113 | + llvm::all_of(getNewShape(), [](int64_t s) { return s >= 0; }); |
| 1114 | + if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || |
| 1115 | + (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { |
| 1116 | + return emitOpError() << "cannot reshape " << inputElementsNum |
| 1117 | + << " elements into " << newShapeElementsNum; |
1102 | 1118 | }
|
1103 | 1119 | }
|
1104 | 1120 |
|
|
0 commit comments