|
30 | 30 | #include "llvm/ADT/DenseMap.h"
|
31 | 31 | #include "llvm/ADT/TypeSwitch.h"
|
32 | 32 |
|
| 33 | +#include <numeric> |
| 34 | + |
33 | 35 | using namespace mlir;
|
34 | 36 | using namespace mlir::tosa;
|
35 | 37 |
|
@@ -1077,12 +1079,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
|
1077 | 1079 | << newShapeDim;
|
1078 | 1080 | }
|
1079 | 1081 |
|
1080 |
| - if (inputType.hasStaticShape() && outputType.hasStaticShape()) { |
| 1082 | + if (inputType.hasStaticShape()) { |
1081 | 1083 | int64_t inputElementsNum = inputType.getNumElements();
|
1082 |
| - int64_t outputElementsNum = outputType.getNumElements(); |
1083 |
| - if (inputElementsNum != outputElementsNum) { |
| 1084 | + if (outputType.hasStaticShape()) { |
| 1085 | + int64_t outputElementsNum = outputType.getNumElements(); |
| 1086 | + if (inputElementsNum != outputElementsNum) { |
| 1087 | + return emitOpError() << "cannot reshape " << inputElementsNum |
| 1088 | + << " elements into " << outputElementsNum; |
| 1089 | + } |
| 1090 | + } |
| 1091 | + |
| 1092 | + int64_t newShapeElementsNum = std::accumulate( |
| 1093 | + getNewShape().begin(), getNewShape().end(), 1LL, |
| 1094 | + [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); |
| 1095 | + bool isStaticNewShape = |
| 1096 | + llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; }); |
| 1097 | + if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || |
| 1098 | + (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) { |
1084 | 1099 | return emitOpError() << "cannot reshape " << inputElementsNum
|
1085 |
| - << " elements into " << outputElementsNum; |
| 1100 | + << " elements into " << newShapeElementsNum; |
1086 | 1101 | }
|
1087 | 1102 |
|
1088 | 1103 | if ((int64_t)getNewShape().size() != outputType.getRank()) {
|
|
0 commit comments