Skip to content

Commit 190b9ba

Browse files
committed
[AutoBump] Merge with fixes of e6eb94d (Sep 20)
2 parents a00dc5b + e6eb94d commit 190b9ba

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include "llvm/ADT/DenseMap.h"
3131
#include "llvm/ADT/TypeSwitch.h"
3232

33+
#include <numeric>
34+
3335
using namespace mlir;
3436
using namespace mlir::tosa;
3537

@@ -1077,12 +1079,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
10771079
<< newShapeDim;
10781080
}
10791081

1080-
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1082+
if (inputType.hasStaticShape()) {
10811083
int64_t inputElementsNum = inputType.getNumElements();
1082-
int64_t outputElementsNum = outputType.getNumElements();
1083-
if (inputElementsNum != outputElementsNum) {
1084+
if (outputType.hasStaticShape()) {
1085+
int64_t outputElementsNum = outputType.getNumElements();
1086+
if (inputElementsNum != outputElementsNum) {
1087+
return emitOpError() << "cannot reshape " << inputElementsNum
1088+
<< " elements into " << outputElementsNum;
1089+
}
1090+
}
1091+
1092+
int64_t newShapeElementsNum = std::accumulate(
1093+
getNewShape().begin(), getNewShape().end(), 1LL,
1094+
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1095+
bool isStaticNewShape =
1096+
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1097+
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1098+
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
10841099
return emitOpError() << "cannot reshape " << inputElementsNum
1085-
<< " elements into " << outputElementsNum;
1100+
<< " elements into " << newShapeElementsNum;
10861101
}
10871102

10881103
if ((int64_t)getNewShape().size() != outputType.getRank()) {

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
404404

405405
// -----
406406

407+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
408+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
409+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
410+
return
411+
}
412+
413+
// -----
414+
415+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
416+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
417+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
418+
return
419+
}
420+
421+
// -----
422+
407423
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
408424
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
409425
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>

0 commit comments

Comments
 (0)