[torch] Rework torch.repeat to not broadcast unary case (#4061)
Not all dimensions in `torch.repeat` need to be broadcast. For dimensions whose
repeat count is 1, skip the unsqueeze and the matching flatten instead of
materializing a unit dimension for them.
rsuderman authored Mar 4, 2025
1 parent 91a6c15 · commit 0d0653a
Showing 3 changed files with 97 additions and 68 deletions.
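For context, `torch.repeat` tiles the input along each listed repeat, and a repeat count of 1 leaves that dimension unchanged, so no broadcast dimension is needed for it. The following is a small standalone sketch of the shape bookkeeping behind the rework, not the pattern itself; `planRepeat` and the `-1` dynamic-size sentinel are local stand-ins for the torch-mlir equivalents.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for torch-mlir's dynamic-size sentinel (assumption for this sketch).
constexpr int64_t kUnknownSize = -1;

// Given an input shape and repeat counts, compute the shape fed to the
// broadcast and the final result shape. Leading repeats with no matching
// input dim ("batch" repeats) always add a new dim; existing dims only get
// an extra dim when their repeat count is not 1.
void planRepeat(const std::vector<int64_t> &inShape,
                const std::vector<int64_t> &repeats,
                std::vector<int64_t> &expandShape,
                std::vector<int64_t> &resultShape) {
  size_t batch = repeats.size() - inShape.size();
  expandShape.clear();
  resultShape.clear();
  for (size_t i = 0; i < batch; ++i) {
    expandShape.push_back(repeats[i]);
    resultShape.push_back(repeats[i]);
  }
  for (size_t d = 0; d < inShape.size(); ++d) {
    int64_t repeat = repeats[batch + d];
    int64_t size = inShape[d];
    if (repeat != 1) // unary repeats keep their dim untouched
      expandShape.push_back(repeat);
    expandShape.push_back(size);
    // The trailing flatten folds the repeat dim into the original dim.
    resultShape.push_back(repeat == kUnknownSize || size == kUnknownSize
                              ? kUnknownSize
                              : repeat * size);
  }
}

int main() {
  std::vector<int64_t> expand, result;
  planRepeat({4, 2}, {3, 1, 2}, expand, result);
  for (int64_t v : expand)
    std::cout << v << ' '; // 3 4 2 2
  std::cout << '\n';
  for (int64_t v : result)
    std::cout << v << ' '; // 3 4 4
  std::cout << '\n';
}

For a [4,2] input with repeats {3, 1, 2} this prints the broadcast target 3 4 2 2 and the result 3 4 4; the previous decomposition unsqueezed every dimension, broadcast at rank 5, and flattened twice.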
lib/Conversion/TorchToLinalg/Utils.cpp (30 changes: 27 additions & 3 deletions)
@@ -457,13 +457,37 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     return success();
   }
 
+  // Flatten size-1 broadcast dims to simplify the final generic op.
+  // If all dims are size-1 broadcast dims, then this will collapse to a
+  // rank-0 tensor.
+  SmallVector<ReassociationIndices> collapseExprs;
+  for (int64_t i = 0, e = inputRank; i < e; ++i) {
+    if (!broadcastedStatus[i]) {
+      collapseExprs.push_back({});
+    }
+  }
+
+  int64_t previous = -1;
+  bool collapse = false;
   SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
-    if (broadcastedStatus[i]) {
-      inputExprs.push_back(rewriter.getAffineConstantExpr(0));
+    if (!broadcastedStatus[i]) {
+      previous++;
+      collapseExprs[previous].push_back(i);
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
       continue;
     }
-    inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
+
+    int64_t clamped = previous < 0 ? 0 : previous;
+    if (!collapseExprs.empty()) {
+      collapseExprs[clamped].push_back(i);
+    }
+    collapse = true;
+  }
+
+  if (collapse) {
+    input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
+                                                     collapseExprs);
   }
 
   SmallVector<AffineMap> indexingMaps = {
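The reassociation logic added above can be read in isolation as follows; this is a rough sketch with `std::vector` standing in for `ReassociationIndices`, and `collapseGroups` is an illustrative name, not a helper in the repository. Every dim that is not being broadcast opens a group, and each size-1 broadcast dim is folded into the nearest preceding group (or into the first group when it leads the shape).

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// broadcastedStatus[i] is true when input dim i is a size-1 dim that the op
// broadcasts; such dims carry no layout information and can be collapsed away
// before the linalg.generic is built.
std::vector<std::vector<int64_t>>
collapseGroups(const std::vector<bool> &broadcastedStatus) {
  std::vector<std::vector<int64_t>> groups;
  for (size_t i = 0; i < broadcastedStatus.size(); ++i)
    if (!broadcastedStatus[i])
      groups.push_back({});

  int64_t previous = -1;
  for (size_t i = 0; i < broadcastedStatus.size(); ++i) {
    if (!broadcastedStatus[i]) {
      groups[++previous].push_back(i);
      continue;
    }
    // A broadcast dim joins the nearest preceding group, or the first group
    // when it leads; with no groups at all the input collapses to rank 0.
    if (!groups.empty())
      groups[previous < 0 ? 0 : previous].push_back(i);
  }
  return groups;
}

int main() {
  // The updated broadcast.mlir case: tensor<1x1x2xf32> -> tensor<1x4x2xf32>,
  // where only dim 1 is broadcast.
  for (const auto &group : collapseGroups({false, true, false})) {
    std::cout << '[';
    for (size_t j = 0; j < group.size(); ++j)
      std::cout << group[j] << (j + 1 < group.size() ? ", " : "");
    std::cout << "] ";
  }
  std::cout << '\n'; // prints: [0, 1] [2]
}

Applied to the updated broadcast.mlir test below (tensor<1x1x2xf32> broadcast to tensor<1x4x2xf32>, where only dim 1 broadcasts), it yields [0, 1] [2], the same reassociation the emitted tensor.collapse_shape carries.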
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (130 changes: 67 additions & 63 deletions)
@@ -4426,88 +4426,92 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp>
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    // Materialize out 1 dimensions to broadcast along. This includes
-    // materializing out preceding batch dimensions:
-    for (int i = 0; i < repeatSz; ++i) {
-      auto oldSizes = selfTy.getSizes();
-      llvm::SmallVector<int64_t> sizes;
-      int64_t squeezeDim = i < batch ? i : i * 2 - batch;
-
-      for (int j = 0; j < squeezeDim; ++j)
-        sizes.push_back(oldSizes[j]);
-      sizes.push_back(1);
-      for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
-        sizes.push_back(oldSizes[j]);
-
-      Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
-      selfTy =
-          rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
-      self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
+    // Fold the constant values so that we know which we materialize:
+    llvm::SmallVector<int64_t> repeatInts;
+    for (int i = 0, s = repeats.size(); i < s; ++i) {
+      int64_t repeat;
+      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
+        repeat = Torch::kUnknownSize;
+
+      repeatInts.push_back(repeat);
+    }
+
+    // Unsqueeze all newly created dims
+    llvm::SmallVector<int> unsqueezeDims;
+    for (int i = 0; i < batch; ++i) {
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(i);
+    }
+
+    // Unsqueeze any non-unary repeats for existing dims
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+      int64_t dim = i + unsqueezeDims.size() - batch;
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(dim);
     }
 
     // Materialize the expansion sizes for each dim:
     llvm::SmallVector<Value> lengths;
-    for (int i = 0; i < repeatSz; ++i) {
-      if (i < batch) {
-        lengths.push_back(repeats[i]);
-        continue;
-      }
-
-      Value iv = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
-      Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
-      lengths.push_back(repeats[i]);
-      lengths.push_back(dim);
+    llvm::SmallVector<int64_t> expandShape;
+    for (int i = 0; i < batch; ++i) {
+      lengths.push_back(repeats[i]);
+      expandShape.push_back(repeatInts[i]);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] != 1) {
+        lengths.push_back(repeats[i]);
+        expandShape.push_back(repeatInts[i]);
+      }
+
+      int dim = lengths.size();
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
+      lengths.push_back(dimV);
+      expandShape.push_back(selfTy.getSizes()[dim]);
     }
 
     // Materialize the broadcast:
     Value lengthv = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(rewriter.getType<IntType>()), lengths);
+    selfTy = rewriter.getType<ValueTensorType>(expandShape,
+                                               selfTy.getOptionalDtype());
+    self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
-    llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
-    for (int i = 0; i < repeatSz; ++i) {
-      int64_t repeatDim = i < batch ? i : i * 2 - batch;
-      int64_t repeat;
-      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
-        repeat = Torch::kUnknownSize;
-      expandShape[repeatDim] = repeat;
-    }
-
-    auto mulDim = [](int64_t lhs, int64_t rhs) {
-      if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
-        return Torch::kUnknownSize;
-      return lhs * rhs;
-    };
-
-    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
-        expandShape, selfTy.getOptionalDtype());
-    Value expand =
-        rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
-
-    for (int i = 0; i < rank; ++i) {
-      auto oldShape = expandTy.getSizes();
-      llvm::SmallVector<int64_t> newShape;
-      int64_t flattenDim = i + batch;
-      for (int j = 0; j < flattenDim; ++j)
-        newShape.push_back(oldShape[j]);
-      newShape.push_back(
-          mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
-      for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
-        newShape.push_back(oldShape[j]);
-
-      expandTy = rewriter.getType<ValueTensorType>(newShape,
-                                                   expandTy.getOptionalDtype());
-
-      // Used to keep the return type the same on the last flatten:
-      expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
-
-      Value start = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim));
-      Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim + 1));
-      expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
-                                                       start, end);
+    auto outShape = cast<ValueTensorType>(op.getResult().getType()).getSizes();
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+
+      auto selfShape = selfTy.getSizes();
+      llvm::SmallVector<int64_t> flattenShape;
+      for (int j = 0; j <= i; ++j)
+        flattenShape.push_back(outShape[j]);
+
+      for (int j = i + 2, s = selfShape.size(); j < s; ++j)
+        flattenShape.push_back(selfShape[j]);
+
+      selfTy = rewriter.getType<ValueTensorType>(flattenShape,
+                                                 selfTy.getOptionalDtype());
+      Value start =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i + 1));
+
+      self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
+                                                     end);
     }
 
-    rewriter.replaceOp(op, expand);
+    rewriter.replaceOp(op, self);
     return success();
   }
 };
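To make the new unsqueeze placement concrete, here is a self-contained sketch of the index bookkeeping used above; it assumes every repeat count folds to a constant (the pattern records non-constant repeats as kUnknownSize and treats them as non-unary), and `unsqueezePositions` is an illustrative name only.

#include <cstdint>
#include <iostream>
#include <vector>

// Which dims receive a new unit dimension under the reworked decomposition:
// every new leading "batch" repeat always does, while an existing dim is only
// unsqueezed when its repeat count is not 1.
std::vector<int64_t> unsqueezePositions(int64_t inputRank,
                                        const std::vector<int64_t> &repeats) {
  int64_t batch = static_cast<int64_t>(repeats.size()) - inputRank;
  std::vector<int64_t> dims;
  for (int64_t i = 0; i < batch; ++i)
    dims.push_back(i); // new leading dims are always materialized
  for (int64_t i = batch, s = repeats.size(); i < s; ++i) {
    if (repeats[i] == 1)
      continue; // unary repeat: leave the existing dim alone
    dims.push_back(i + static_cast<int64_t>(dims.size()) - batch);
  }
  return dims;
}

int main() {
  // Rank-3 input, repeats = {2, 1, 3, 1}: one new batch dim plus one
  // non-unary repeat, so only two unsqueezes are emitted.
  for (int64_t d : unsqueezePositions(3, {2, 1, 3, 1}))
    std::cout << d << ' '; // prints: 0 2
  std::cout << '\n';
}

For a rank-3 input with repeats {2, 1, 3, 1} it reports unsqueezes at dims 0 and 2 only, where the old decomposition inserted a unit dim for every repeat entry.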
test/Conversion/TorchToLinalg/broadcast.mlir (5 changes: 3 additions & 2 deletions)
@@ -22,10 +22,11 @@ func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32
 
 // CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
 // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
 // CHECK: %[[GENERIC:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
 // CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
 // CHECK: linalg.yield %[[IN]] : f32
 // CHECK: } -> tensor<1x4x2xf32>