Skip to content

Commit

Permalink
Make linearize_index to have SAME basis
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Dec 3, 2024
1 parent d04d2bc commit 7a6321f
Showing 1 changed file with 31 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,26 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
return indicesAfterCollapsing;
}

static bool hasDelinearizeIndexVal(OperandRange vals) {
return llvm::any_of(vals, [](OpFoldResult index) {
static FailureOr<SmallVector<OpFoldResult>> hasDelinearizeIndexVal(
OperandRange vals) {
Operation *delinearizeIndexOp = nullptr;
for (OpFoldResult index : vals) {
if (auto val = dyn_cast<Value>(index)) {
if (dyn_cast_if_present<affine::AffineDelinearizeIndexOp>(
val.getDefiningOp()))
return true;
val.getDefiningOp())) {
if (delinearizeIndexOp && (delinearizeIndexOp != val.getDefiningOp())) {
return failure();
}
delinearizeIndexOp = val.getDefiningOp();
}
}
return false;
});
}
return cast<affine::AffineDelinearizeIndexOp>(delinearizeIndexOp)
.getMixedBasis();
// return llvm::any_of(vals, [](OpFoldResult index) {

// return false;
// });
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
Expand Down Expand Up @@ -201,14 +212,16 @@ class FlattenContiguousRowMajorTransferReadPattern
// affine.delinearize_index op, we will form the new indices using
// affine.linearize_index else we use affine.apply/map to form the same.
SmallVector<Value> collapsedIndices;
if (hasDelinearizeIndexVal(transferReadOp.getIndices())) {
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, transferReadOp.getIndices(), sourceType.getShape(), true);
collapsedIndices.push_back(linearizedIndices);
} else {
FailureOr<SmallVector<OpFoldResult>> maybeBasis =
hasDelinearizeIndexVal(transferReadOp.getIndices());
if (failed(maybeBasis)) {
collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferReadOp.getIndices(), firstDimToCollapse);
} else {
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, transferReadOp.getIndices(), *maybeBasis, true);
collapsedIndices.push_back(linearizedIndices);
}

// 3. Create new vector.transfer_read that reads from the collapsed memref
Expand Down Expand Up @@ -301,14 +314,16 @@ class FlattenContiguousRowMajorTransferWritePattern
// affine.delinearize_index op, we will form the new indices using
// affine.linearize_index else we use affine.apply/map to form the same.
SmallVector<Value> collapsedIndices;
if (hasDelinearizeIndexVal(transferWriteOp.getIndices())) {
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, transferWriteOp.getIndices(), sourceType.getShape(), true);
collapsedIndices.push_back(linearizedIndices);
} else {
FailureOr<SmallVector<OpFoldResult>> maybeBasis =
hasDelinearizeIndexVal(transferWriteOp.getIndices());
if (failed(maybeBasis)) {
collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferWriteOp.getIndices(), firstDimToCollapse);
} else {
Value linearizedIndices = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, transferWriteOp.getIndices(), *maybeBasis, true);
collapsedIndices.push_back(linearizedIndices);
}

// 3. Create new vector.transfer_write that writes to the collapsed memref
Expand Down

0 comments on commit 7a6321f

Please sign in to comment.