@@ -890,7 +890,7 @@ LogicalResult tosa::SliceOp::verify() {
890
890
if (!inputType || !outputType)
891
891
return success ();
892
892
893
- if (inputType.getRank () != outputType.getRank ()) {
893
+ if (inputType.getRank () != outputType.getRank ()) {
894
894
return emitOpError () << " rank of input (" << inputType.getRank ()
895
895
<< " ) and output (" << outputType.getRank ()
896
896
<< " ) must match" ;
@@ -1087,34 +1087,35 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
1087
1087
return emitOpError () << " cannot reshape " << inputElementsNum
1088
1088
<< " elements into " << outputElementsNum;
1089
1089
}
1090
+
1091
+ if ((int64_t )getNewShape ().size () != outputType.getRank ()) {
1092
+ return emitOpError ()
1093
+ << " rank of newShape (" << getNewShape ().size ()
1094
+ << " ) and output (" << outputType.getRank () << " ) must match" ;
1095
+ }
1096
+
1097
+ for (int64_t dim = 0 ; dim < outputType.getRank (); ++dim) {
1098
+ if (getNewShape ()[dim] != -1 &&
1099
+ getNewShape ()[dim] != outputType.getShape ()[dim]) {
1100
+ return emitOpError ()
1101
+ << " newShape attribute (" << getNewShape ()[dim]
1102
+ << " ) does not match output type ("
1103
+ << outputType.getShape ()[dim] << " ) in dimension " << dim;
1104
+ }
1105
+ }
1090
1106
}
1091
1107
1108
+ // AMD: Switched checks with > to >= to allow zero dimensions
1092
1109
int64_t newShapeElementsNum = std::accumulate (
1093
1110
getNewShape ().begin (), getNewShape ().end (), 1LL ,
1094
- [](int64_t acc, int64_t dim) { return (dim > 0 ) ? acc * dim : acc; });
1111
+ [](int64_t acc, int64_t dim) { return (dim >= 0 ) ? acc * dim : acc; });
1095
1112
bool isStaticNewShape =
1096
- llvm::all_of (getNewShape (), [](int64_t s) { return s > 0 ; });
1113
+ llvm::all_of (getNewShape (), [](int64_t s) { return s >= 0 ; });
1097
1114
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1098
1115
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1099
1116
return emitOpError () << " cannot reshape " << inputElementsNum
1100
1117
<< " elements into " << newShapeElementsNum;
1101
1118
}
1102
-
1103
- if ((int64_t )getNewShape ().size () != outputType.getRank ()) {
1104
- return emitOpError () << " rank of newShape (" << getNewShape ().size ()
1105
- << " ) and output (" << outputType.getRank ()
1106
- << " ) must match" ;
1107
- }
1108
-
1109
- for (int64_t dim = 0 ; dim < outputType.getRank (); ++dim) {
1110
- if (getNewShape ()[dim] != -1 &&
1111
- getNewShape ()[dim] != outputType.getShape ()[dim]) {
1112
- return emitOpError ()
1113
- << " newShape attribute (" << getNewShape ()[dim]
1114
- << " ) does not match output type (" << outputType.getShape ()[dim]
1115
- << " ) in dimension " << dim;
1116
- }
1117
- }
1118
1119
}
1119
1120
1120
1121
int missingDims = llvm::count (getNewShape (), -1 );
0 commit comments