From f5b8fdc3ce35e97543a7393483dd5cbe6a3a59e4 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 27 Feb 2025 17:21:41 -0800
Subject: [PATCH 1/9] [torch] Rework `torch.repeat` to not broadcast unary case

Not every dimension in `torch.repeat` needs to be broadcast: dimensions
with a repeat count of 1 can pass through unchanged. Skip the
unsqueeze/broadcast/flatten round trip for these unary dimensions and only
materialize the dimensions that are actually repeated.
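
For example (hand-written IR; the shapes here are made up for this note
rather than taken from a test), repeating a `[4,5]` input with repeats
`[2, 1]` now only materializes the non-unary dim, while the unary repeat
on dim 1 no longer gets its own unsqueeze/broadcast/flatten round trip:

  %u = torch.aten.unsqueeze %arg0, %int0
      : !torch.vtensor<[4,5],f32>, !torch.int -> !torch.vtensor<[1,4,5],f32>
  %b = torch.aten.broadcast_to %u, %sizes
      : !torch.vtensor<[1,4,5],f32>, !torch.list<int> -> !torch.vtensor<[2,4,5],f32>
  %r = torch.aten.flatten.using_ints %b, %int0, %int1
      : !torch.vtensor<[2,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[8,5],f32>

(%sizes is the prim.ListConstruct of [2, 4, 5]; it and the integer
constants are elided above.)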
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 137 +++++++++++-------
 1 file changed, 81 insertions(+), 56 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1226ad2c03e2..af529da906c3 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4426,50 +4426,80 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    // Materialize out 1 dimensions to broadcast along. This includes
-    // materializing out preceding batch dimensions:
-    for (int i = 0; i < repeatSz; ++i) {
-      auto oldSizes = selfTy.getSizes();
-      llvm::SmallVector<int64_t> sizes;
-      int64_t squeezeDim = i < batch ? i : i * 2 - batch;
+    auto unsqueeze = [&](Value input, int64_t dim) {
+      auto inputTy = input.getType().cast<BaseTensorType>();
+      auto oldSizes = inputTy.getSizes();
 
-      for (int j = 0; j < squeezeDim; ++j)
+      llvm::SmallVector<int64_t> sizes;
+      for (int j = 0; j < dim; ++j)
         sizes.push_back(oldSizes[j]);
       sizes.push_back(1);
-      for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
+      for (int j = dim, s = oldSizes.size(); j < s; j++)
         sizes.push_back(oldSizes[j]);
 
-      Value dim = rewriter.create<ConstantIntOp>(loc, squeezeDim);
-      selfTy =
-          rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
-      self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
+      Value dimVal = rewriter.create<ConstantIntOp>(loc, dim);
+      inputTy =
+          rewriter.getType<ValueTensorType>(sizes, inputTy.getOptionalDtype());
+      input = rewriter.create<AtenUnsqueezeOp>(loc, inputTy, input, dimVal);
+
+      return input;
+    };
+
+    // Fold the constant values so that we know which we materialize:
+    llvm::SmallVector<int64_t> repeatInts;
+    for (int i = 0, s = repeats.size(); i < s; ++i) {
+      int64_t repeat;
+      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
+        repeat = Torch::kUnknownSize;
+
+      repeatInts.push_back(repeat);
     }
 
+    // Unsqueeze any non-unary repeats:
+    llvm::SmallVector<int64_t> unsqueezeDims;
+    for (int i = 0; i < batch; ++i) {
+      self = unsqueeze(self, i);
+      selfTy = self.getType().cast<BaseTensorType>();
+      unsqueezeDims.push_back(i);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+      int64_t dim = i + unsqueezeDims.size() - batch;
+      self = unsqueeze(self, dim);
+      selfTy = self.getType().cast<BaseTensorType>();
+      unsqueezeDims.push_back(dim);
+    }
+
+    // Materialize the expansion sizes for each dim:
     llvm::SmallVector<Value> lengths;
-    for (int i = 0; i < repeatSz; ++i) {
-      if (i < batch) {
+    llvm::SmallVector<int64_t> expandShape;
+    for (int i = 0; i < batch; ++i) {
+      lengths.push_back(repeats[i]);
+      expandShape.push_back(repeatInts[i]);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] != 1) {
         lengths.push_back(repeats[i]);
-        continue;
+        expandShape.push_back(repeatInts[i]);
       }
 
-      Value iv = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
-      Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
-      lengths.push_back(repeats[i]);
-      lengths.push_back(dim);
+      int dim = lengths.size();
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
+      lengths.push_back(dimV);
+      expandShape.push_back(selfTy.getSizes()[dim]);
     }
 
+    // Materialize the broadcast:
     Value lengthv = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(rewriter.getType<IntType>()), lengths);
-
-    llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
-    for (int i = 0; i < repeatSz; ++i) {
-      int64_t repeatDim = i < batch ? i : i * 2 - batch;
-      int64_t repeat;
-      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
-        repeat = Torch::kUnknownSize;
-      expandShape[repeatDim] = repeat;
-    }
+    selfTy = rewriter.getType<ValueTensorType>(expandShape,
+                                               selfTy.getOptionalDtype());
+    self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
     auto mulDim = [](int64_t lhs, int64_t rhs) {
       if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
@@ -4477,37 +4507,32 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       return lhs * rhs;
     };
 
-    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
-        expandShape, selfTy.getOptionalDtype());
-    Value expand =
-        rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
 
-    for (int i = 0; i < rank; ++i) {
-      auto oldShape = expandTy.getSizes();
-      llvm::SmallVector<int64_t> newShape;
-      int64_t flattenDim = i + batch;
-      for (int j = 0; j < flattenDim; ++j)
-        newShape.push_back(oldShape[j]);
-      newShape.push_back(
-          mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
-      for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
-        newShape.push_back(oldShape[j]);
-
-      expandTy = rewriter.getType<ValueTensorType>(newShape,
-                                                   expandTy.getOptionalDtype());
-
-      // Used to keep the return type the same on the last flatten:
-      expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
-
-      Value start = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim));
+      auto selfShape = selfTy.getSizes();
+      llvm::SmallVector<int64_t> flattenShape;
+      for (int j = 0; j < i; ++j)
+        flattenShape.push_back(selfTy.getSizes()[j]);
+
+      flattenShape.push_back(mulDim(selfShape[i], selfShape[i + 1]));
+
+      for (int j = i + 2, s = selfShape.size(); j < s; ++j)
+        flattenShape.push_back(selfShape[j]);
+
+      selfTy = rewriter.getType<ValueTensorType>(flattenShape,
+                                                 selfTy.getOptionalDtype());
+      Value start =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
       Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim + 1));
-      expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
-                                                       start, end);
+          loc, rewriter.getI64IntegerAttr(i + 1));
+
+      self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
+                                                     end);
     }
 
-    rewriter.replaceOp(op, expand);
+    rewriter.replaceOp(op, self);
     return success();
   }
 };

From f3fb393093f9bbef7d8ba8a6b4965083f113e75c Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 27 Feb 2025 17:30:48 -0800
Subject: [PATCH 2/9] fix for cast deprecation

Use the free function `cast<>` instead of the deprecated member
`Type::cast<>`.
---
 lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index af529da906c3..1644b923abd2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4427,7 +4427,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
     auto unsqueeze = [&](Value input, int64_t dim) {
-      auto inputTy = input.getType().cast<BaseTensorType>();
+      auto inputTy = cast<BaseTensorType>(input.getType());
       auto oldSizes = inputTy.getSizes();
 
       llvm::SmallVector<int64_t> sizes;
@@ -4459,7 +4459,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     llvm::SmallVector<int64_t> unsqueezeDims;
     for (int i = 0; i < batch; ++i) {
       self = unsqueeze(self, i);
-      selfTy = self.getType().cast<BaseTensorType>();
+      selfTy = cast<BaseTensorType>(self.getType());
       unsqueezeDims.push_back(i);
     }
 
@@ -4468,7 +4468,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
         continue;
       int64_t dim = i + unsqueezeDims.size() - batch;
       self = unsqueeze(self, dim);
-      selfTy = self.getType().cast<BaseTensorType>();
+      selfTy = cast<BaseTensorType>(self.getType());
       unsqueezeDims.push_back(dim);
     }
From f2f3d620c19296a6ad6c546bff7134f48b78f4e5 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 27 Feb 2025 20:44:26 -0800
Subject: [PATCH 3/9] fix case for shape inference

Take the flattened sizes from the op's inferred result type instead of
re-deriving them with `mulDim`, so the decomposed types agree with shape
inference.
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1644b923abd2..d89eb35d475b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4501,22 +4501,15 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
                                                selfTy.getOptionalDtype());
     self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
-    auto mulDim = [](int64_t lhs, int64_t rhs) {
-      if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
-        return Torch::kUnknownSize;
-      return lhs * rhs;
-    };
-
+    auto outShape = cast<BaseTensorType>(op.getResult().getType()).getSizes();
    for (int i = batch, s = repeats.size(); i < s; ++i) {
       if (repeatInts[i] == 1)
         continue;
 
       auto selfShape = selfTy.getSizes();
       llvm::SmallVector<int64_t> flattenShape;
-      for (int j = 0; j < i; ++j)
-        flattenShape.push_back(selfTy.getSizes()[j]);
-
-      flattenShape.push_back(mulDim(selfShape[i], selfShape[i + 1]));
+      for (int j = 0; j <= i; ++j)
+        flattenShape.push_back(outShape[j]);
 
       for (int j = i + 2, s = selfShape.size(); j < s; ++j)
         flattenShape.push_back(selfShape[j]);

From ad06c3a33797a4f1c288b778f7b2882b298876b1 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 28 Feb 2025 17:05:15 -0800
Subject: [PATCH 4/9] add some broadcast changes too
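
Size-1 broadcast dims of the input are now collapsed away before the
linalg.generic is emitted, instead of being pinned to constant-0 affine
exprs, and the repeat decomposition switches to the shared unsqueezeTensor
helper. A sketch of the intended lowering for the [1,1,2] -> [1,4,2]
numpy-style broadcast already exercised by the tests (SSA names
abbreviated):

  %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
      : tensor<1x1x2xf32> into tensor<1x2xf32>
  %result = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%collapsed : tensor<1x2xf32>) outs(%init : tensor<1x4x2xf32>) {
  ^bb0(%in_elem: f32, %out_elem: f32):
    linalg.yield %in_elem : f32
  } -> tensor<1x4x2xf32>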
---
 lib/Conversion/TorchToLinalg/Utils.cpp        | 27 ++++++++++++++---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 30 +++++--------------
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 98dbc1957892..5eec36b55432 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -457,13 +457,32 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     return success();
   }
 
+  SmallVector<ReassociationIndices> collapseExprs;
+  for (int64_t i = 0, e = inputRank; i < e; ++i) {
+    if (!broadcastedStatus[i]) {
+      collapseExprs.push_back({});
+    }
+  }
+
+  int64_t previous = -1;
+  for (int64_t i = 0, e = inputRank; i < e; ++i) {
+    if (!broadcastedStatus[i]) {
+      previous++;
+      collapseExprs[previous].push_back(i + diff);
+    } else {
+      int64_t clamped = previous < 0 ? 0 : previous;
+      collapseExprs[clamped].push_back(i + diff);
+    }
+  }
+
+  input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
+                                                   collapseExprs);
+
   SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
-    if (broadcastedStatus[i]) {
-      inputExprs.push_back(rewriter.getAffineConstantExpr(0));
-      continue;
+    if (!broadcastedStatus[i]) {
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
     }
-    inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
   }
 
   SmallVector<AffineMap> indexingMaps = {
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index d89eb35d475b..57f8c0c09484 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4426,25 +4426,6 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    auto unsqueeze = [&](Value input, int64_t dim) {
-      auto inputTy = cast<BaseTensorType>(input.getType());
-      auto oldSizes = inputTy.getSizes();
-
-      llvm::SmallVector<int64_t> sizes;
-      for (int j = 0; j < dim; ++j)
-        sizes.push_back(oldSizes[j]);
-      sizes.push_back(1);
-      for (int j = dim, s = oldSizes.size(); j < s; j++)
-        sizes.push_back(oldSizes[j]);
-
-      Value dimVal = rewriter.create<ConstantIntOp>(loc, dim);
-      inputTy =
-          rewriter.getType<ValueTensorType>(sizes, inputTy.getOptionalDtype());
-      input = rewriter.create<AtenUnsqueezeOp>(loc, inputTy, input, dimVal);
-
-      return input;
-    };
-
     // Fold the constant values so that we know which we materialize:
     llvm::SmallVector<int64_t> repeatInts;
     for (int i = 0, s = repeats.size(); i < s; ++i) {
@@ -4455,19 +4436,24 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       repeatInts.push_back(repeat);
     }
 
-    // Unsqueeze any non-unary repeats:
+    // Unsqueeze all newly created dims
     llvm::SmallVector<int64_t> unsqueezeDims;
     for (int i = 0; i < batch; ++i) {
-      self = unsqueeze(self, i);
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
       selfTy = cast<BaseTensorType>(self.getType());
       unsqueezeDims.push_back(i);
     }
 
+    // Unsqueeze any non-unary repeats for existing dims
     for (int i = batch, s = repeats.size(); i < s; ++i) {
       if (repeatInts[i] == 1)
         continue;
       int64_t dim = i + unsqueezeDims.size() - batch;
-      self = unsqueeze(self, dim);
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
       selfTy = cast<BaseTensorType>(self.getType());
       unsqueezeDims.push_back(dim);
     }

From bdae8c44f851b0985ea61419b0f3054e9742596e Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 28 Feb 2025 17:24:23 -0800
Subject: [PATCH 5/9] make collapse conditional

Only emit the tensor.collapse_shape when some dim actually folds away.
---
 lib/Conversion/TorchToLinalg/Utils.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 5eec36b55432..3c8c27381deb 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -475,8 +475,10 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     }
   }
 
-  input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
-                                                   collapseExprs);
+  if (collapseExprs.size() < inputRank) {
+    input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
+                                                     collapseExprs);
+  }
 
   SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {

From bf7c1c708e450eabe8e6b0d3461a915934b48f10 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 28 Feb 2025 17:30:25 -0800
Subject: [PATCH 6/9] fix comparison failure

Cast the rank before comparing against `collapseExprs.size()` to avoid
the signed/unsigned comparison warning.
---
 lib/Conversion/TorchToLinalg/Utils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 3c8c27381deb..aa13dacc3378 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -475,7 +475,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     }
   }
 
-  if (collapseExprs.size() < inputRank) {
+  if (collapseExprs.size() < static_cast<uint64_t>(inputRank)) {
     input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
                                                      collapseExprs);
   }
From 04762685da8fa8d180e5569bd19932952e39e0fc Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 28 Feb 2025 17:43:18 -0800
Subject: [PATCH 7/9] fix checks

Update the broadcast test to expect the new tensor.collapse_shape and the
simplified input indexing map.
---
 test/Conversion/TorchToLinalg/broadcast.mlir | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/Conversion/TorchToLinalg/broadcast.mlir b/test/Conversion/TorchToLinalg/broadcast.mlir
index 8841ba704328..4d3f7194bd84 100644
--- a/test/Conversion/TorchToLinalg/broadcast.mlir
+++ b/test/Conversion/TorchToLinalg/broadcast.mlir
@@ -22,10 +22,11 @@ func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32
 
 // CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
 // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
 // CHECK: %[[GENERIC:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
 // CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
 // CHECK:   linalg.yield %[[IN]] : f32
 // CHECK: } -> tensor<1x4x2xf32>

From 82f7f77039fb6ae396c4380b94d236701b7530bd Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 28 Feb 2025 18:23:55 -0800
Subject: [PATCH 8/9] fix test failures
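
The reassociation groups handed to tensor.collapse_shape index the dims of
the collapse *input*, not of the broadcast output, so offsetting them by
the input/output rank difference produced invalid groups whenever the two
ranks differed (my reading of the CI failures). Build the groups in
input-dim space instead, and key the collapse off an explicit flag. For a
rank-2 input the only legal reassociation indices are 0 and 1, e.g.:

  %collapsed = tensor.collapse_shape %in [[0, 1]]
      : tensor<1x2xf32> into tensor<2xf32>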
---
 lib/Conversion/TorchToLinalg/Utils.cpp | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index aa13dacc3378..456ea855bf46 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -465,17 +465,22 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
   }
 
   int64_t previous = -1;
+  bool collapse = false;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
     if (!broadcastedStatus[i]) {
       previous++;
-      collapseExprs[previous].push_back(i + diff);
-    } else {
-      int64_t clamped = previous < 0 ? 0 : previous;
-      collapseExprs[clamped].push_back(i + diff);
+      collapseExprs[previous].push_back(i);
+      continue;
+    }
+
+    int64_t clamped = previous < 0 ? 0 : previous;
+    if (!collapseExprs.empty()) {
+      collapseExprs[clamped].push_back(i);
     }
+    collapse = true;
   }
 
-  if (collapseExprs.size() < static_cast<uint64_t>(inputRank)) {
+  if (collapse) {
     input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
                                                      collapseExprs);
   }

From 563d84f79ae3bf5cebfb227a189d69ad50f17fef Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 3 Mar 2025 16:01:58 -0800
Subject: [PATCH 9/9] Code review comments

Document the collapse behavior and build inputExprs alongside the
reassociation groups in a single loop.
---
 lib/Conversion/TorchToLinalg/Utils.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 456ea855bf46..f1f58367da3a 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -457,6 +457,9 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     return success();
   }
 
+  // Flatten size-1 broadcast dims to simplify the final generic op.
+  // If all dims are size-1 broadcast dims, then this will collapse to a
+  // rank-0 tensor.
   SmallVector<ReassociationIndices> collapseExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
     if (!broadcastedStatus[i]) {
@@ -466,10 +469,12 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
 
   int64_t previous = -1;
   bool collapse = false;
+  SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
     if (!broadcastedStatus[i]) {
       previous++;
       collapseExprs[previous].push_back(i);
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
       continue;
     }
 
@@ -485,13 +490,6 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
                                                      collapseExprs);
   }
 
-  SmallVector<AffineExpr> inputExprs;
-  for (int64_t i = 0, e = inputRank; i < e; ++i) {
-    if (!broadcastedStatus[i]) {
-      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
-    }
-  }
-
   SmallVector<AffineMap> indexingMaps = {
       AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()),
       rewriter.getMultiDimIdentityMap(outputRank)};