Addressing zjgarvey's feedback.
Ivan Garcia committed Mar 6, 2025
1 parent 762af3d commit 4d27ee3
Showing 1 changed file with 16 additions and 11 deletions.
lib/Conversion/TorchToLinalg/Linear.cpp
@@ -956,21 +956,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}

- // This code was moved earlier because in the grouped transposed convolution
- // case we need to expand before doing the dimension permutation. For the
- // grouped non-transposed convolution, we don't need to do filter/channel
- // dimension flipping, we can just expand the group from the filter in place
- // to have the group dimension in front:
+ // The expandWeight lambda function below is used to expand the group
+ // dimension. For the normal case the group dimension is expanded out
+ // of the output filter dimension:
// expand F,C,H,W -> G,F/G,C,H,W
//
- // When we have grouped transposed convolution we need to first expand the
- // input channel: expand C,F,H,W -> G,C/G,F,H,W
+ // Note that the group dimension has to be the first dimension. For the
+ // transposed convolution case, the group dimension is extracted out
+ // of the input channel dimension. However, the input channel
+ // dimension is interchanged with the output filter dimension (due to
+ // the transpose operation). Because of this the group and input
+ // channel dimensions would not be adjacent and the expand_shape
+ // operation would not work.
//
- // And then flip the output filters with the input channel to make it linalg
- // compatible: permute G,C/G,F,H,W -> G,F,C/G,H,W
+ // For this reason, in the transposed convolution case the expandWeight
+ // lambda needs to be executed before the dimension flipping; the
+ // weight is prepared in these two steps:
//
- // Notice that if the flipping happens first, then we can't move the group
- // dimension to the front as the linalg convolution operation requires.
+ // Expansion: C,F,H,W -> G,C/G,F,H,W
+ //
+ // Dimension interchange: G,C/G,F,H,W -> G,F,C/G,H,W
//
auto expandWeight = [&](Value tensor) {
auto inType = cast<RankedTensorType>(tensor.getType());
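To make the two layout steps concrete, here is a minimal standalone C++ sketch, not part of the patch: the Shape alias and the expandNormal, expandTransposed, and interchange helpers are hypothetical stand-ins that model only the dimension arithmetic described in the new comment, not the actual expand_shape and permutation ops the pattern emits.

// Minimal sketch (hypothetical helpers, not from Linear.cpp): models the
// weight-shape bookkeeping for grouped convolutions with plain vectors.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Normal grouped convolution: the group is split out of the output-filter
// dimension, which is already dim 0: F,C,H,W -> G,F/G,C,H,W.
Shape expandNormal(const Shape &fchw, int64_t g) {
  assert(fchw[0] % g == 0 && "F must be divisible by G");
  return {g, fchw[0] / g, fchw[1], fchw[2], fchw[3]};
}

// Transposed case, step 1 (expansion): the weight layout is C,F,H,W because
// the transpose interchanges input channel and output filter, so the group
// is split out of dim 0 = C: C,F,H,W -> G,C/G,F,H,W.
Shape expandTransposed(const Shape &cfhw, int64_t g) {
  assert(cfhw[0] % g == 0 && "C must be divisible by G");
  return {g, cfhw[0] / g, cfhw[1], cfhw[2], cfhw[3]};
}

// Transposed case, step 2 (dimension interchange): swap C/G and F so the
// layout is linalg-compatible: G,C/G,F,H,W -> G,F,C/G,H,W.
Shape interchange(const Shape &s) { return {s[0], s[2], s[1], s[3], s[4]}; }

int main() {
  Shape weight = {4, 6, 3, 3};               // C=4, F=6, 3x3 kernel
  Shape step1 = expandTransposed(weight, 2); // {2, 2, 6, 3, 3}
  Shape step2 = interchange(step1);          // {2, 6, 2, 3, 3}
  for (int64_t d : step2)
    std::cout << d << ' ';                   // prints: 2 6 2 3 3
  std::cout << '\n';
}

Note how the ordering matches the comment's reasoning: if the interchange ran first (C,F,H,W -> F,C,H,W), the channel to be split would sit at dim 1, the expansion would leave the new group dimension at dim 1 rather than in front, and no reshape alone could move it there.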
