
Commit a9b80ee

Author: Ivan Garcia
Add support for transposed grouped convolution in torch to linalg lowering

1 parent: 6ea4d11

4 files changed: +178 −31 lines

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 72 additions & 31 deletions

@@ -955,7 +955,50 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       if (isa<mlir::IntegerType>(inputDTy))
         pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
     }
+
+    // This code was moved earlier because in the grouped transposed
+    // convolution case we need to expand before doing the dimension
+    // permutation. For the grouped non-transposed convolution we don't need
+    // any filter/channel dimension flipping; we can expand the group from the
+    // filter in place, so that the group dimension ends up in front:
+    //   expand F,C,H,W -> G,F/G,C,H,W
+    //
+    // For the grouped transposed convolution we first expand the input
+    // channel:
+    //   expand C,F,H,W -> G,C/G,F,H,W
+    //
+    // and then swap the output filters with the input channel to make it
+    // linalg compatible:
+    //   permute G,C/G,F,H,W -> G,F,C/G,H,W
+    //
+    // Notice that if the flipping happened first, we could not move the group
+    // dimension to the front, as the linalg convolution operation requires.
+    auto expandWeight = [&](Value tensor) {
+      auto inType = cast<RankedTensorType>(tensor.getType());
+      auto inShape = makeShapeTorchCompatible(inType.getShape());
+
+      SmallVector<int64_t> outShape{numGroups,
+                                    (inShape[0] == kUnknownSize
+                                         ? kUnknownSize
+                                         : (inShape[0] / numGroups)),
+                                    inShape[1]};
+      outShape.append(inShape.begin() + 2, inShape.end());
+
+      SmallVector<ReassociationIndices> indices{};
+      int currIndex = 0;
+      indices.push_back({0, 1});
+      currIndex += 2;
+      for (int i = currIndex; i <= (long)inShape.size(); i++)
+        indices.push_back({i});
+
+      auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
+      return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
+                                                    indices);
+    };
+
     if (transposed) {
+      bool isGroupedConv = numGroups > 1;
+      weight = isGroupedConv ? expandWeight(weight) : weight;
+
       Value c0 =
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
       Value c1 =
@@ -965,25 +1008,40 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

     // Transpose and flip weight
     SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
-    std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
-    outDims[1] = weightInitDims[0];
+    if (isGroupedConv) {
+      // We need to skip the first (group) dimension in this case; the output
+      // channel dimension also needs to account for the number of groups.
+      std::iter_swap(weightInitDims.begin() + 1, weightInitDims.begin() + 2);
+      auto numGroupsVal =
+          rewriter.create<mlir::arith::ConstantIndexOp>(loc, numGroups);
+      outDims[1] = rewriter.createOrFold<mlir::arith::MulIOp>(
+          loc, weightInitDims[1], numGroupsVal);
+    } else {
+      std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
+      outDims[1] = weightInitDims[0];
+    }
+    auto weightRank = weightInitDims.size();
     Value weightInitTensor =
         createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy);
     SmallVector<utils::IteratorType> iteratorTypes(
-        inRank, utils::IteratorType::parallel);
+        weightRank, utils::IteratorType::parallel);
     SmallVector<AffineMap> indexingMaps{
-        AffineMap::getMultiDimIdentityMap(inRank, context)};
+        AffineMap::getMultiDimIdentityMap(weightRank, context)};
     weight = rewriter
                  .create<linalg::GenericOp>(
                      loc, weightInitTensor.getType(), ValueRange{},
                      weightInitTensor, indexingMaps, iteratorTypes,
                      [&](OpBuilder &b, Location loc, ValueRange args) {
                        SmallVector<Value> indices;
-                       for (size_t i = 0; i < inRank; i++)
+                       for (size_t i = 0; i < weightRank; i++)
                          indices.push_back(b.create<linalg::IndexOp>(loc, i));
-                       std::iter_swap(indices.begin(), indices.begin() + 1);
-                       // Flip only the spatial dimensions (from 2 to inRank)
-                       for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
+                       auto fcIdxSwapOffset = isGroupedConv ? 1 : 0;
+                       std::iter_swap(indices.begin() + fcIdxSwapOffset,
+                                      indices.begin() + fcIdxSwapOffset + 1);
+                       // Flip only the spatial dimensions (from
+                       // fcIdxSwapOffset + 2 to weightRank).
+                       for (size_t flipDim = fcIdxSwapOffset + 2;
+                            flipDim < weightRank; flipDim++) {
                          indices[flipDim] = b.create<arith::SubIOp>(
                              loc,
                              b.create<arith::SubIOp>(
@@ -1373,43 +1431,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                                                     indices);
     };

-    // expand F,C,H,W -> G,F/G,C,H,W
-    auto expandWeight = [&](Value tensor) {
-      auto inType = cast<RankedTensorType>(tensor.getType());
-      auto inShape = makeShapeTorchCompatible(inType.getShape());
-
-      SmallVector<int64_t> outShape{
-          numGroups,
-          (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
-      outShape.append(inShape.begin() + 1, inShape.end());
-
-      SmallVector<ReassociationIndices> indices{{0, 1}};
-      for (auto i = 2; i <= (long)inShape.size(); i++)
-        indices.push_back({i});
-
-      auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
-      return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
-                                                    indices);
-    };
-
     Value paddedInputExpanded = expandGroups(paddedInput, 1);
-    Value weightExpanded = expandWeight(weight);
+    // For a transposed convolution the weight was already expanded before the
+    // dimension permutation above; see the comments on the expandWeight
+    // lambda definition for details.
+    weight = transposed ? weight : expandWeight(weight);
     auto expandOutputTensor = expandGroups(outputTensor, 1);

     // TODO: add 1D and 3D case
     if (!inputZp) {
       conv = rewriter
                  .create<linalg::Conv2DNgchwGfchwOp>(
                      loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weightExpanded},
+                     ValueRange{paddedInputExpanded, weight},
                      expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                  .getResult(0);
     } else {
       conv = rewriter
                  .create<linalg::Conv2DNgchwGfchwQOp>(
                      loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
-                                weightZp},
+                     ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
                      expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                  .getResult(0);
     }
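
Aside: the expand-then-permute order that the new expandWeight lambda sets up can be sanity-checked outside the lowering. The following standalone PyTorch sketch (not part of this commit; shapes taken from the new test) mirrors the C,F,H,W -> G,C/G,F,H,W expansion and the G,C/G,F,H,W -> G,F,C/G,H,W permutation on a transposed-convolution weight; the spatial flip the lowering also performs is omitted for brevity:

```python
import torch

G = 2                        # number of groups
w = torch.randn(2, 2, 3, 3)  # transposed-conv weight: (C_in, C_out/G, kH, kW)

# expand C,F,H,W -> G,C/G,F,H,W (mirrors the tensor.expand_shape)
w_grouped = w.reshape(G, w.shape[0] // G, w.shape[1], *w.shape[2:])

# permute G,C/G,F,H,W -> G,F,C/G,H,W (mirrors the linalg.generic transpose)
w_gfchw = w_grouped.permute(0, 2, 1, 3, 4)

print(w_gfchw.shape)  # torch.Size([2, 2, 1, 3, 3]), the g,f,c,h,w filter layout
```

Doing the permutation first would interleave the group and channel dimensions, so the group dimension could no longer be split off as the outermost axis, which is exactly the ordering constraint the comment block above describes.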

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions

@@ -3523,6 +3523,7 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    "ConvolutionModule2DGroupedTranspose_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
@@ -4099,6 +4100,7 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    "ConvolutionModule2DGroupedTranspose_basic",
     "CopyModule_basic",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 32 additions & 0 deletions

@@ -1725,3 +1725,35 @@ def DeformConv2D_basic(module, tu: TestUtils):
     offset = tu.rand(N, offset_dim1, Hout, Wout)
     weight = tu.rand(Cout, Cin, Hker, Wker)
     module.forward(input, offset, weight)
+
+
+class ConvolutionModule2DGroupedTranspose(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 5, 7], torch.float32, True),
+            ([2, 2, 3, 3], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[2, 2],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=True,
+            output_padding=[0, 0],
+            groups=2,
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DGroupedTranspose())
+def ConvolutionModule2DGroupedTranspose_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 5, 7), tu.rand(2, 2, 3, 3), tu.rand(4))
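
As a shape sanity check on the new test (a standalone sketch, not part of the commit): PyTorch computes the transposed-convolution output size as H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1, so the module above should return a 1x4x9x13 tensor:

```python
import torch

x = torch.rand(1, 2, 5, 7)
w = torch.rand(2, 2, 3, 3)  # (C_in, C_out/groups, kH, kW)
b = torch.rand(4)

y = torch.ops.aten.convolution(
    x, w, bias=b, stride=[2, 2], padding=[1, 1], dilation=[1, 1],
    transposed=True, output_padding=[0, 0], groups=2,
)
# H: (5-1)*2 - 2*1 + 1*(3-1) + 0 + 1 = 9
# W: (7-1)*2 - 2*1 + 1*(3-1) + 0 + 1 = 13
# C_out: 2 (per-group filters) * 2 (groups) = 4
print(y.shape)  # torch.Size([1, 4, 9, 13])
```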

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 72 additions & 0 deletions

@@ -76,3 +76,75 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
   %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
   return %2 : !torch.vtensor<[1,1024,3000],f32>
 }
+
+// CHECK-LABEL: func.func @transposedConv2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: = linalg.generic
+// CHECK-SAME: outs(%[[VAR1:.*]] : tensor<4x2x3x3xf32>) {
+// CHECK: %[[VAR2:.*]] = tensor.extract
+// CHECK-SAME: : tensor<2x4x3x3xf32>
+// CHECK-NEXT: linalg.yield %[[VAR3:.*]] : f32
+// CHECK-NEXT: } -> tensor<4x2x3x3xf32>
+// CHECK: %[[VAR4:.*]] = linalg.broadcast ins(%[[VAR5:.*]] : tensor<4xf32>) outs(%[[VAR6:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR7:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR8:.*]], %[[VAR9:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[VAR10:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[VAR11:.*]] = tensor.cast %[[VAR12:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_4_3_3_torch.float32> : tensor<2x4x3x3xf32>) : !torch.vtensor<[2,4,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
+
+// CHECK-LABEL: func.func @groupedConvolution2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
+// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
+// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
+func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_4_2_3_3_torch.float32> : tensor<4x2x3x3xf32>) : !torch.vtensor<[4,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int2 : !torch.vtensor<[1,4,5,7],f32>, !torch.vtensor<[4,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,5,7],f32>
+  return %6 : !torch.vtensor<[1,4,5,7],f32>
+}
+
+// CHECK-LABEL: func.func @transposedGroupedConvolution2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
+// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
+func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_2_3_3_torch.float32> : tensor<2x2x3x3xf32>) : !torch.vtensor<[2,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int2 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
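
The linalg.conv_2d_ngchw_gfchw op matched by these CHECK lines computes one independent convolution per group, with the group dimension outermost in both operands. The grouped-equals-per-group identity it implements can be illustrated with a small standalone PyTorch sketch (illustrative only, not part of this commit):

```python
import torch
import torch.nn.functional as F

G = 2
x = torch.randn(1, 2, 5, 7)
w = torch.randn(2, 2, 3, 3)  # (C_in, C_out/G, kH, kW)

# Grouped transposed convolution in one call.
grouped = F.conv_transpose2d(x, w, stride=2, padding=1, groups=G)

# The same result, computed one group at a time over channel slices.
xs = x.chunk(G, dim=1)  # split input channels into G groups
ws = w.chunk(G, dim=0)  # split the weight along C_in
per_group = torch.cat(
    [F.conv_transpose2d(xg, wg, stride=2, padding=1) for xg, wg in zip(xs, ws)],
    dim=1,
)

print(torch.allclose(grouped, per_group, atol=1e-6))  # True
```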
