diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 98dbc1957892..f1f58367da3a 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -457,13 +457,37 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     return success();
   }
 
+  // Flatten size-1 broadcast dims to simplify the final generic op.
+  // If all dims are size-1 broadcast dims, then this will collapse to a
+  // rank-0 tensor.
+  SmallVector<ReassociationIndices> collapseExprs;
+  for (int64_t i = 0, e = inputRank; i < e; ++i) {
+    if (!broadcastedStatus[i]) {
+      collapseExprs.push_back({});
+    }
+  }
+
+  int64_t previous = -1;
+  bool collapse = false;
   SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
-    if (broadcastedStatus[i]) {
-      inputExprs.push_back(rewriter.getAffineConstantExpr(0));
+    if (!broadcastedStatus[i]) {
+      previous++;
+      collapseExprs[previous].push_back(i);
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
       continue;
     }
-    inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
+
+    int64_t clamped = previous < 0 ? 0 : previous;
+    if (!collapseExprs.empty()) {
+      collapseExprs[clamped].push_back(i);
+    }
+    collapse = true;
+  }
+
+  if (collapse) {
+    input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
+                                                     collapseExprs);
   }
 
   SmallVector<AffineMap> indexingMaps = {
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1226ad2c03e2..57f8c0c09484 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4426,88 +4426,92 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    // Materialize out 1 dimensions to broadcast along. This includes
-    // materializing out preceding batch dimensions:
-    for (int i = 0; i < repeatSz; ++i) {
-      auto oldSizes = selfTy.getSizes();
-      llvm::SmallVector<int64_t> sizes;
-      int64_t squeezeDim = i < batch ? i : i * 2 - batch;
+    // Fold the constant values so that we know which we materialize:
+    llvm::SmallVector<int64_t> repeatInts;
+    for (int i = 0, s = repeats.size(); i < s; ++i) {
+      int64_t repeat;
+      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
+        repeat = Torch::kUnknownSize;
 
-      for (int j = 0; j < squeezeDim; ++j)
-        sizes.push_back(oldSizes[j]);
-      sizes.push_back(1);
-      for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
-        sizes.push_back(oldSizes[j]);
+      repeatInts.push_back(repeat);
+    }
+
+    // Unsqueeze all newly created dims
+    llvm::SmallVector<int64_t> unsqueezeDims;
+    for (int i = 0; i < batch; ++i) {
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(i);
+    }
 
-      Value dim = rewriter.create<ConstantIntOp>(loc, squeezeDim);
-      selfTy =
-          rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
-      self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
+    // Unsqueeze any non-unary repeats for existing dims
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+      int64_t dim = i + unsqueezeDims.size() - batch;
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(dim);
     }
 
+    // Materialize the expansion sizes for each dim:
     llvm::SmallVector<Value> lengths;
-    for (int i = 0; i < repeatSz; ++i) {
-      if (i < batch) {
+    llvm::SmallVector<int64_t> expandShape;
+    for (int i = 0; i < batch; ++i) {
+      lengths.push_back(repeats[i]);
+      expandShape.push_back(repeatInts[i]);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] != 1) {
        lengths.push_back(repeats[i]);
-        continue;
+        expandShape.push_back(repeatInts[i]);
       }
 
-      Value iv = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
-      Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
-      lengths.push_back(repeats[i]);
-      lengths.push_back(dim);
+      int dim = lengths.size();
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
+      lengths.push_back(dimV);
+      expandShape.push_back(selfTy.getSizes()[dim]);
     }
 
+    // Materialize the broadcast:
     Value lengthv = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(rewriter.getType<Torch::IntType>()), lengths);
+    selfTy = rewriter.getType<ValueTensorType>(expandShape,
+                                               selfTy.getOptionalDtype());
+    self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
-    llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
-    for (int i = 0; i < repeatSz; ++i) {
-      int64_t repeatDim = i < batch ? i : i * 2 - batch;
-      int64_t repeat;
-      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
-        repeat = Torch::kUnknownSize;
-      expandShape[repeatDim] = repeat;
-    }
+    auto outShape = cast<ValueTensorType>(op.getResult().getType()).getSizes();
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
 
-    auto mulDim = [](int64_t lhs, int64_t rhs) {
-      if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
-        return Torch::kUnknownSize;
-      return lhs * rhs;
-    };
+      auto selfShape = selfTy.getSizes();
+      llvm::SmallVector<int64_t> flattenShape;
+      for (int j = 0; j <= i; ++j)
+        flattenShape.push_back(outShape[j]);
 
-    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
-        expandShape, selfTy.getOptionalDtype());
-    Value expand =
-        rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
+      for (int j = i + 2, s = selfShape.size(); j < s; ++j)
+        flattenShape.push_back(selfShape[j]);
 
-    for (int i = 0; i < rank; ++i) {
-      auto oldShape = expandTy.getSizes();
-      llvm::SmallVector<int64_t> newShape;
-      int64_t flattenDim = i + batch;
-      for (int j = 0; j < flattenDim; ++j)
-        newShape.push_back(oldShape[j]);
-      newShape.push_back(
-          mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
-      for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
-        newShape.push_back(oldShape[j]);
-
-      expandTy = rewriter.getType<ValueTensorType>(newShape,
-                                                   expandTy.getOptionalDtype());
-
-      // Used to keep the return type the same on the last flatten:
-      expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
-
-      Value start = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim));
+      selfTy = rewriter.getType<ValueTensorType>(flattenShape,
+                                                 selfTy.getOptionalDtype());
+      Value start =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
       Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim + 1));
-      expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
-                                                       start, end);
+          loc, rewriter.getI64IntegerAttr(i + 1));
+
+      self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
+                                                     end);
     }
 
-    rewriter.replaceOp(op, expand);
+    rewriter.replaceOp(op, self);
     return success();
   }
 };
diff --git a/test/Conversion/TorchToLinalg/broadcast.mlir b/test/Conversion/TorchToLinalg/broadcast.mlir
index 8841ba704328..4d3f7194bd84 100644
--- a/test/Conversion/TorchToLinalg/broadcast.mlir
+++ b/test/Conversion/TorchToLinalg/broadcast.mlir
@@ -22,10 +22,11 @@ func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32
 
 // CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
 // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
 // CHECK: %[[GENERIC:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
 // CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
 // CHECK: linalg.yield %[[IN]] : f32
 // CHECK: } -> tensor<1x4x2xf32>