[torch] Rework torch.repeat to not broadcast unary case #4061

Merged: 9 commits, Mar 4, 2025
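For readers skimming the diff: torch.repeat tiles each existing dimension by its repeat factor and prepends a new dimension for every extra leading repeat, so a repeat factor of 1 leaves a dimension untouched. That is what this patch exploits: the decomposition no longer materializes an unsqueeze/broadcast/flatten round-trip for unary repeats. The following standalone C++ sketch is not part of the patch (`repeatShape` is an illustrative helper, not a function in the codebase); it only shows the shape rule being relied on.

```cpp
// Illustrative only: shape rule of aten.repeat. New leading dims come
// straight from the repeat counts; existing dims are multiplied by theirs.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> repeatShape(const std::vector<int64_t> &self,
                                 const std::vector<int64_t> &repeats) {
  size_t batch = repeats.size() - self.size(); // extra leading repeats
  std::vector<int64_t> out;
  for (size_t i = 0; i < repeats.size(); ++i)
    out.push_back(i < batch ? repeats[i] : repeats[i] * self[i - batch]);
  return out;
}

int main() {
  // self = [2, 3], repeats = [4, 1, 2] -> [4, 2, 6]; the middle dim has a
  // repeat factor of 1 and needs no broadcast at all.
  for (int64_t d : repeatShape({2, 3}, {4, 1, 2}))
    std::cout << d << ' ';
  std::cout << '\n';
}
```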
30 changes: 27 additions & 3 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -457,13 +457,37 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     return success();
   }
 
+  // Flatten size-1 broadcast dims to simplify the final generic op.
+  // If all dims are size-1 broadcast dims, then this will collapse to a
+  // rank-0 tensor.
+  SmallVector<ReassociationIndices> collapseExprs;
+  for (int64_t i = 0, e = inputRank; i < e; ++i) {
+    if (!broadcastedStatus[i]) {
+      collapseExprs.push_back({});
+    }
+  }
+
+  int64_t previous = -1;
+  bool collapse = false;
   SmallVector<AffineExpr> inputExprs;
   for (int64_t i = 0, e = inputRank; i < e; ++i) {
-    if (broadcastedStatus[i]) {
-      inputExprs.push_back(rewriter.getAffineConstantExpr(0));
+    if (!broadcastedStatus[i]) {
+      previous++;
+      collapseExprs[previous].push_back(i);
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
       continue;
     }
-    inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
+
+    int64_t clamped = previous < 0 ? 0 : previous;
+    if (!collapseExprs.empty()) {
+      collapseExprs[clamped].push_back(i);
+    }
+    collapse = true;
  }
 
+  if (collapse) {
+    input = rewriter.create<tensor::CollapseShapeOp>(op->getLoc(), input,
+                                                     collapseExprs);
+  }
+
   SmallVector<AffineMap> indexingMaps = {
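The Utils.cpp change above folds every size-1 broadcast dim of the input into the nearest non-broadcast dim before emitting the linalg.generic. Below is a standalone sketch of that grouping, not part of the patch: `collapseGroups` is an illustrative name and a plain `std::vector` stands in for the MLIR reassociation indices.

```cpp
// Illustrative only: compute collapse_shape reassociation groups so that
// each size-1 broadcast dim is merged into an adjacent non-broadcast dim.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<int64_t>>
collapseGroups(const std::vector<bool> &broadcastedStatus) {
  // One group per non-broadcast dim; if every dim is a size-1 broadcast dim
  // this stays empty and the input collapses to a rank-0 tensor.
  std::vector<std::vector<int64_t>> groups;
  for (bool isBroadcastDim : broadcastedStatus)
    if (!isBroadcastDim)
      groups.emplace_back();

  int64_t previous = -1;
  for (int64_t i = 0, e = broadcastedStatus.size(); i < e; ++i) {
    if (!broadcastedStatus[i]) {
      groups[++previous].push_back(i); // non-broadcast dim owns a group
      continue;
    }
    if (!groups.empty()) // broadcast dim: fold into current (or first) group
      groups[previous < 0 ? 0 : previous].push_back(i);
  }
  return groups;
}

int main() {
  // tensor<1x1x2xf32> broadcast to tensor<1x4x2xf32>: only dim 1 broadcasts,
  // so the groups are [[0, 1], [2]] and the input collapses to tensor<1x2xf32>,
  // which is exactly what the updated broadcast.mlir test checks below.
  for (const auto &group : collapseGroups({false, true, false})) {
    std::cout << "[ ";
    for (int64_t d : group)
      std::cout << d << ' ';
    std::cout << "] ";
  }
  std::cout << '\n';
}
```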
130 changes: 67 additions & 63 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4426,88 +4426,92 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
     if (!selfTy.hasSizes())
       return rewriter.notifyMatchFailure(op, "input sizes unknown");
 
-    // Materialize out 1 dimensions to broadcast along. This includes
-    // materializing out preceding batch dimensions:
-    for (int i = 0; i < repeatSz; ++i) {
-      auto oldSizes = selfTy.getSizes();
-      llvm::SmallVector<int64_t> sizes;
-      int64_t squeezeDim = i < batch ? i : i * 2 - batch;
+    // Fold the constant values so that we know which we materialize:
+    llvm::SmallVector<int64_t> repeatInts;
+    for (int i = 0, s = repeats.size(); i < s; ++i) {
+      int64_t repeat;
+      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
+        repeat = Torch::kUnknownSize;
 
-      for (int j = 0; j < squeezeDim; ++j)
-        sizes.push_back(oldSizes[j]);
-      sizes.push_back(1);
-      for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
-        sizes.push_back(oldSizes[j]);
+      repeatInts.push_back(repeat);
+    }
+
+    // Unsqueeze all newly created dims
+    llvm::SmallVector<int> unsqueezeDims;
+    for (int i = 0; i < batch; ++i) {
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(i);
+    }
 
-      Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
-      selfTy =
-          rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
-      self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
+    // Unsqueeze any non-unary repeats for existing dims
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
+      int64_t dim = i + unsqueezeDims.size() - batch;
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      self = *unsqueezeTensor(rewriter, op, self, iv);
+      selfTy = cast<ValueTensorType>(self.getType());
+      unsqueezeDims.push_back(dim);
     }
 
+    // Materialize the expansion sizes for each dim:
     llvm::SmallVector<Value> lengths;
-    for (int i = 0; i < repeatSz; ++i) {
-      if (i < batch) {
+    llvm::SmallVector<int64_t> expandShape;
+    for (int i = 0; i < batch; ++i) {
+      lengths.push_back(repeats[i]);
+      expandShape.push_back(repeatInts[i]);
+    }
+
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] != 1) {
         lengths.push_back(repeats[i]);
-        continue;
+        expandShape.push_back(repeatInts[i]);
       }
 
-      Value iv = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
-      Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
-      lengths.push_back(repeats[i]);
-      lengths.push_back(dim);
+      int dim = lengths.size();
+      Value iv =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+      Value dimV = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
+      lengths.push_back(dimV);
+      expandShape.push_back(selfTy.getSizes()[dim]);
    }
 
+    // Materialize the broadcast:
     Value lengthv = rewriter.create<PrimListConstructOp>(
         loc, ListType::get(rewriter.getType<IntType>()), lengths);
+    selfTy = rewriter.getType<ValueTensorType>(expandShape,
+                                               selfTy.getOptionalDtype());
+    self = rewriter.create<AtenBroadcastToOp>(loc, selfTy, self, lengthv);
 
-    llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
-    for (int i = 0; i < repeatSz; ++i) {
-      int64_t repeatDim = i < batch ? i : i * 2 - batch;
-      int64_t repeat;
-      if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
-        repeat = Torch::kUnknownSize;
-      expandShape[repeatDim] = repeat;
-    }
+    auto outShape = cast<ValueTensorType>(op.getResult().getType()).getSizes();
+    for (int i = batch, s = repeats.size(); i < s; ++i) {
+      if (repeatInts[i] == 1)
+        continue;
 
-    auto mulDim = [](int64_t lhs, int64_t rhs) {
-      if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
-        return Torch::kUnknownSize;
-      return lhs * rhs;
-    };
+      auto selfShape = selfTy.getSizes();
+      llvm::SmallVector<int64_t> flattenShape;
+      for (int j = 0; j <= i; ++j)
+        flattenShape.push_back(outShape[j]);
 
-    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
-        expandShape, selfTy.getOptionalDtype());
-    Value expand =
-        rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
+      for (int j = i + 2, s = selfShape.size(); j < s; ++j)
+        flattenShape.push_back(selfShape[j]);
 
-    for (int i = 0; i < rank; ++i) {
-      auto oldShape = expandTy.getSizes();
-      llvm::SmallVector<int64_t> newShape;
-      int64_t flattenDim = i + batch;
-      for (int j = 0; j < flattenDim; ++j)
-        newShape.push_back(oldShape[j]);
-      newShape.push_back(
-          mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
-      for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
-        newShape.push_back(oldShape[j]);
-
-      expandTy = rewriter.getType<ValueTensorType>(newShape,
-                                                   expandTy.getOptionalDtype());
-
-      // Used to keep the return type the same on the last flatten:
-      expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
-
-      Value start = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim));
+      selfTy = rewriter.getType<ValueTensorType>(flattenShape,
+                                                 selfTy.getOptionalDtype());
+      Value start =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
       Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(flattenDim + 1));
-      expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
-                                                       start, end);
+          loc, rewriter.getI64IntegerAttr(i + 1));
+
+      self = rewriter.create<AtenFlattenUsingIntsOp>(loc, selfTy, self, start,
+                                                     end);
    }
 
-    rewriter.replaceOp(op, expand);
+    rewriter.replaceOp(op, self);
     return success();
   }
 };
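To see what the reworked DecomposeAtenRepeatOp produces, here is a standalone sketch that is not part of the patch: it assumes every repeat factor is statically known, only tracks shapes, and uses `repeatSteps` as an illustrative name. It walks the unsqueeze, broadcast, and flatten chain that the new decomposition applies only to dims whose repeat factor is not 1; under the old decomposition these three steps ran for every dim, so even a repeat of 1 produced an unsqueeze, a broadcast, and a flatten.

```cpp
// Illustrative only: intermediate shapes of the reworked decomposition.
// Dims with a repeat factor of 1 pass through untouched; every other dim
// gets a size-1 dim unsqueezed in front of it, broadcast to the repeat
// factor, and flattened back into its neighbour.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

std::vector<Shape> repeatSteps(const Shape &self, const Shape &repeats) {
  size_t batch = repeats.size() - self.size();

  // Step 1: unsqueeze new leading dims and every non-unary existing dim.
  Shape unsqueezed(batch, 1);
  for (size_t i = batch; i < repeats.size(); ++i) {
    if (repeats[i] != 1)
      unsqueezed.push_back(1);
    unsqueezed.push_back(self[i - batch]);
  }

  // Step 2: broadcast each inserted size-1 dim to its repeat factor.
  Shape expanded = unsqueezed;
  size_t pos = 0;
  for (size_t i = 0; i < repeats.size(); ++i) {
    if (i < batch) {
      expanded[pos++] = repeats[i];
    } else if (repeats[i] != 1) {
      expanded[pos] = repeats[i];
      pos += 2; // skip the original dim that follows the inserted one
    } else {
      ++pos;
    }
  }

  // Step 3: flatten every (repeat, size) pair back into a single dim.
  Shape flattened;
  pos = 0;
  for (size_t i = 0; i < repeats.size(); ++i) {
    if (i >= batch && repeats[i] != 1) {
      flattened.push_back(expanded[pos] * expanded[pos + 1]);
      pos += 2;
    } else {
      flattened.push_back(expanded[pos++]);
    }
  }
  return {unsqueezed, expanded, flattened};
}

int main() {
  // self = [2, 3], repeats = [4, 1, 2]:
  // unsqueezed [1, 2, 1, 3] -> expanded [4, 2, 2, 3] -> flattened [4, 2, 6]
  for (const Shape &s : repeatSteps({2, 3}, {4, 1, 2})) {
    for (int64_t d : s)
      std::cout << d << ' ';
    std::cout << "| ";
  }
  std::cout << '\n';
}
```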
5 changes: 3 additions & 2 deletions test/Conversion/TorchToLinalg/broadcast.mlir
@@ -22,10 +22,11 @@ func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32
 
 // CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast(
 // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32>
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
 // CHECK: %[[GENERIC:.*]] = linalg.generic
-// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) {
 // CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
 // CHECK: linalg.yield %[[IN]] : f32
 // CHECK: } -> tensor<1x4x2xf32>
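The updated broadcast.mlir test reflects the new lowering: the 1x1x2 input is first collapsed to 1x2, and the generic's input map simply drops the broadcast dimension instead of pinning it to the constant 0. A tiny sketch of what that indexing map means at runtime follows; it is an illustrative loop nest, not generated code.

```cpp
// Illustrative only: semantics of indexing_maps [(d0, d1, d2) -> (d0, d2),
// (d0, d1, d2) -> (d0, d1, d2)] for the collapsed tensor<1x2xf32> input.
#include <iostream>

int main() {
  float in[1][2] = {{10.f, 20.f}}; // collapsed input
  float out[1][4][2];              // broadcast result
  for (int d0 = 0; d0 < 1; ++d0)
    for (int d1 = 0; d1 < 4; ++d1)
      for (int d2 = 0; d2 < 2; ++d2)
        out[d0][d1][d2] = in[d0][d2]; // d1 is simply not used on the input
  std::cout << out[0][3][1] << '\n'; // 20: value replicated along d1
}
```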