diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b8c20bc73f65..db5afbf06c7b 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -404,14 +404,21 @@ class ConvertAtenReflectionPad2dOp Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; - assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && - "Left padding too large"); - assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && - "Right padding too large"); - assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && - "Top padding too large"); - assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && - "Bottom padding too large"); + auto verifyPadding = [&](int64_t padArgument, int64_t dim, + StringRef errorMessage) { + auto padValue = rewriter.create(loc, padArgument); + Value index = rewriter.create(loc, dim); + Value shapeDim = rewriter.create(loc, input, index); + Value cmpPred = rewriter.create( + loc, arith::CmpIPredicate::sle, padValue, shapeDim); + rewriter.create(loc, cmpPred, + rewriter.getStringAttr(errorMessage)); + }; + + verifyPadding(getHPadArgument(LEFT), hDim, "Left padding too large"); + verifyPadding(getHPadArgument(RIGHT), hDim, "Right padding too large"); + verifyPadding(getVPadArgument(TOP), vDim, "Top padding too large"); + verifyPadding(getVPadArgument(BOTTOM), vDim, "Bottom padding too large"); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType);