Skip to content

Commit

Permalink
Add runtime asserts for aten.reflection.pad_2d (#4057)
Browse files Browse the repository at this point in the history
Add runtime asserts to check padding constraints of
aten.reflection.pad_2d for dynamic dims
  • Loading branch information
praveen-g-ctt authored Mar 3, 2025
1 parent 1aacb46 commit 32aff8c
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,21 @@ class ConvertAtenReflectionPad2dOp
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];

assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
"Left padding too large");
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
"Right padding too large");
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
"Top padding too large");
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large");
auto verifyPadding = [&](int64_t padArgument, int64_t dim,
StringRef errorMessage) {
auto padValue = rewriter.create<arith::ConstantIndexOp>(loc, padArgument);
Value index = rewriter.create<arith::ConstantIndexOp>(loc, dim);
Value shapeDim = rewriter.create<tensor::DimOp>(loc, input, index);
Value cmpPred = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, padValue, shapeDim);
rewriter.create<cf::AssertOp>(loc, cmpPred,
rewriter.getStringAttr(errorMessage));
};

verifyPadding(getHPadArgument(LEFT), hDim, "Left padding too large");
verifyPadding(getHPadArgument(RIGHT), hDim, "Right padding too large");
verifyPadding(getVPadArgument(TOP), vDim, "Top padding too large");
verifyPadding(getVPadArgument(BOTTOM), vDim, "Bottom padding too large");

Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Expand Down

0 comments on commit 32aff8c

Please sign in to comment.