@@ -956,21 +956,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
956
956
pad = rewriter.create <arith::TruncIOp>(op.getLoc (), inputDTy, pad);
957
957
}
958
958
959
- // This code was moved earlier because in the grouped transposed convolution
960
- // case we need to expand before doing the dimension permutation. For the
961
- // grouped non-transposed convolution, we don't need to do filter/channel
962
- // dimension flipping, we can just expand the group from the filter in place
963
- // to have the group dimension in front:
959
+ // The expandWeight lambda function below is used to expand the group
960
+ // dimension. For the normal case the group dimension is expanded out
961
+ // of the output filter dimension:
964
962
// expand F,C,H,W -> G,F/G,C,H,W
965
963
//
966
- // When we have grouped transposed convolution we need to first expand the
967
- // input channel: expand C,F,H,W -> G,C/G,F,H,W
964
+ // Note that the group dimension has to be the first dimension. For the
965
+ // transposed convolution case, the group convolution is extracted out
966
+ // of the input channel dimension. But note that the input channel
967
+ // dimension is interchanged with the output filter dimension (due to
968
+ // the transposed operation). Because of this the group and input
969
+ // channel dimensions will not be adjacent and the expand_shape
970
+ // operation will not work.
968
971
//
969
- // And then flip the output filters with the input channel to make it linalg
970
- // compatible: permute G,C/G,F,H,W -> G,F,C/G,H,W
972
+ // For this reason, in the transposed convolution case the expandWeight
973
+ // lambda needs to be executed before this dimension flipping by doing
974
+ // these two steps:
971
975
//
972
- // Notice that if the flipping happens first, then we can't move the group
973
- // dimension to the front as the linalg convolution operation requires.
976
+ // Expansion: C,F,H,W -> G,C/G,F,H,W
977
+ //
978
+ // Dimension interchange: G,C/G,F,H,W -> G,F,C/G,H,W
974
979
//
975
980
auto expandWeight = [&](Value tensor) {
976
981
auto inType = cast<RankedTensorType>(tensor.getType ());
0 commit comments