Skip to content

Commit d061bda

Browse files
authored
Merge pull request #422 from Xilinx/bump_to_e6eb94d3
[AutoBump] Merge with fixes of e6eb94d (Sep 20) (8)
2 parents d554286 + b56783c commit d061bda

File tree

2 files changed

+51
-19
lines changed

2 files changed

+51
-19
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+35-19
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include "llvm/ADT/DenseMap.h"
3131
#include "llvm/ADT/TypeSwitch.h"
3232

33+
#include <numeric>
34+
3335
using namespace mlir;
3436
using namespace mlir::tosa;
3537

@@ -888,7 +890,7 @@ LogicalResult tosa::SliceOp::verify() {
888890
if (!inputType || !outputType)
889891
return success();
890892

891-
if (inputType.getRank() != outputType.getRank()) {
893+
if (inputType.getRank() != outputType.getRank()) {
892894
return emitOpError() << "rank of input (" << inputType.getRank()
893895
<< ") and output (" << outputType.getRank()
894896
<< ") must match";
@@ -1077,28 +1079,42 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
10771079
<< newShapeDim;
10781080
}
10791081

1080-
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1082+
if (inputType.hasStaticShape()) {
10811083
int64_t inputElementsNum = inputType.getNumElements();
1082-
int64_t outputElementsNum = outputType.getNumElements();
1083-
if (inputElementsNum != outputElementsNum) {
1084-
return emitOpError() << "cannot reshape " << inputElementsNum
1085-
<< " elements into " << outputElementsNum;
1086-
}
1084+
if (outputType.hasStaticShape()) {
1085+
int64_t outputElementsNum = outputType.getNumElements();
1086+
if (inputElementsNum != outputElementsNum) {
1087+
return emitOpError() << "cannot reshape " << inputElementsNum
1088+
<< " elements into " << outputElementsNum;
1089+
}
1090+
1091+
if ((int64_t)getNewShape().size() != outputType.getRank()) {
1092+
return emitOpError()
1093+
<< "rank of newShape (" << getNewShape().size()
1094+
<< ") and output (" << outputType.getRank() << ") must match";
1095+
}
10871096

1088-
if ((int64_t)getNewShape().size() != outputType.getRank()) {
1089-
return emitOpError() << "rank of newShape (" << getNewShape().size()
1090-
<< ") and output (" << outputType.getRank()
1091-
<< ") must match";
1097+
for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
1098+
if (getNewShape()[dim] != -1 &&
1099+
getNewShape()[dim] != outputType.getShape()[dim]) {
1100+
return emitOpError()
1101+
<< "newShape attribute (" << getNewShape()[dim]
1102+
<< ") does not match output type ("
1103+
<< outputType.getShape()[dim] << ") in dimension " << dim;
1104+
}
1105+
}
10921106
}
10931107

1094-
for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
1095-
if (getNewShape()[dim] != -1 &&
1096-
getNewShape()[dim] != outputType.getShape()[dim]) {
1097-
return emitOpError()
1098-
<< "newShape attribute (" << getNewShape()[dim]
1099-
<< ") does not match output type (" << outputType.getShape()[dim]
1100-
<< ") in dimension " << dim;
1101-
}
1108+
// AMD: Switched checks with > to >= to allow zero dimensions
1109+
int64_t newShapeElementsNum = std::accumulate(
1110+
getNewShape().begin(), getNewShape().end(), 1LL,
1111+
[](int64_t acc, int64_t dim) { return (dim >= 0) ? acc * dim : acc; });
1112+
bool isStaticNewShape =
1113+
llvm::all_of(getNewShape(), [](int64_t s) { return s >= 0; });
1114+
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1115+
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1116+
return emitOpError() << "cannot reshape " << inputElementsNum
1117+
<< " elements into " << newShapeElementsNum;
11021118
}
11031119
}
11041120

mlir/test/Dialect/Tosa/invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
404404

405405
// -----
406406

407+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
408+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
409+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
410+
return
411+
}
412+
413+
// -----
414+
415+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
416+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
417+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
418+
return
419+
}
420+
421+
// -----
422+
407423
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
408424
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
409425
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>

0 commit comments

Comments
 (0)