
Commit 4d27ee3

Author: Ivan Garcia
Commit message: Addressing zjgarvey's feedback.
1 parent: 762af3d

File tree

1 file changed (+16, -11 lines)

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 16 additions & 11 deletions
@@ -956,21 +956,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
     }
 
-    // This code was moved earlier because in the grouped transposed convolution
-    // case we need to expand before doing the dimension permutation. For the
-    // grouped non-transposed convolution, we don't need to do filter/channel
-    // dimension flipping, we can just expand the group from the filter in place
-    // to have the group dimension in front:
+    // The expandWeight lambda function below is used to expand the group
+    // dimension. For the normal case the group dimension is expanded out
+    // of the output filter dimension:
     // expand F,C,H,W -> G,F/G,C,H,W
     //
-    // When we have grouped transposed convolution we need to first expand the
-    // input channel: expand C,F,H,W -> G,C/G,F,H,W
+    // Note that the group dimension has to be the first dimension. For the
+    // transposed convolution case, the group dimension is extracted out
+    // of the input channel dimension. But note that the input channel
+    // dimension is interchanged with the output filter dimension (due to
+    // the transposed operation). Because of this the group and input
+    // channel dimensions will not be adjacent and the expand_shape
+    // operation will not work.
     //
-    // And then flip the output filters with the input channel to make it linalg
-    // compatible: permute G,C/G,F,H,W -> G,F,C/G,H,W
+    // For this reason, in the transposed convolution case the expandWeight
+    // lambda needs to be executed before this dimension flipping by doing
+    // these two steps:
     //
-    // Notice that if the flipping happens first, then we can't move the group
-    // dimension to the front as the linalg convolution operation requires.
+    // Expansion: C,F,H,W -> G,C/G,F,H,W
+    //
+    // Dimension interchange: G,C/G,F,H,W -> G,F,C/G,H,W
     //
     auto expandWeight = [&](Value tensor) {
       auto inType = cast<RankedTensorType>(tensor.getType());
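The expand-then-interchange ordering described in the new comment can be sketched with plain shape arithmetic. The sizes below (G=2 groups, a C,F,H,W weight of 4,6,3,3) are hypothetical, chosen only to make the two steps concrete; they are not taken from the patch:

```python
# Hypothetical transposed-conv weight shape (C, F, H, W); G must divide C.
G = 2
C, F, H, W = 4, 6, 3, 3

# Step 1 -- Expansion: split the input-channel dim while it is still the
# leading dim, so the group dim lands in front:
#   C,F,H,W -> G,C/G,F,H,W
expanded = [G, C // G, F, H, W]

# Step 2 -- Dimension interchange: swap F with C/G so the layout matches
# what linalg's grouped convolution expects:
#   G,C/G,F,H,W -> G,F,C/G,H,W
perm = [0, 2, 1, 3, 4]
final = [expanded[i] for i in perm]
print(final)  # [2, 6, 2, 3, 3] -- group dim G is in front

# If the interchange were done first, (C,F,H,W) -> (F,C,H,W), the group
# split would land at dim 1: (F,G,C/G,H,W). A single expand_shape of
# adjacent dims cannot then move G to the front, which is the problem
# the comment describes.
flipped_then_expanded = [F, G, C // G, H, W]
print(flipped_then_expanded)  # [6, 2, 2, 3, 3] -- G stuck at dim 1
```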

0 commit comments