Commit 762af3d (parent: b3d28ff)
Author: Ivan Garcia

    Addressing Sayan's feedback.


1 file changed: test/Conversion/TorchToLinalg/convolution.mlir (21 additions, 18 deletions)
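For reference, FileCheck tests like this one are driven by a RUN line at the top of the .mlir file; a minimal sketch of the usual invocation for a TorchToLinalg test (hedged: the exact flag set in convolution.mlir may differ):

// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s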
@@ -78,17 +78,18 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
 }
 
 // CHECK-LABEL: func.func @transposedConv2D(
-// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
 // CHECK: = linalg.generic
-// CHECK-SAME: outs(%[[VAR1:.*]] : tensor<4x2x3x3xf32>) {
-// CHECK: %[[VAR2:.*]] = tensor.extract
+// CHECK-SAME: outs(%[[BROADCASTED_WEIGHTS_INIT:.*]] : tensor<4x2x3x3xf32>) {
+// CHECK: %[[WEIGHTS:.*]] = tensor.extract
 // CHECK-SAME: : tensor<2x4x3x3xf32>
-// CHECK-NEXT: linalg.yield %[[VAR3:.*]] : f32
+// CHECK-NEXT: linalg.yield %[[BROADCASTED_WEIGHTS:.*]] : f32
 // CHECK-NEXT: } -> tensor<4x2x3x3xf32>
-// CHECK: %[[VAR4:.*]] = linalg.broadcast ins(%[[VAR5:.*]] : tensor<4xf32>) outs(%[[VAR6:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
-// CHECK: %[[VAR7:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[VAR8:.*]], %[[VAR9:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[VAR10:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
-// CHECK-NEXT: %[[VAR11:.*]] = tensor.cast %[[VAR12:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[CONV_RESULT:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
 func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
   %true = torch.constant.bool true
@@ -105,11 +106,11 @@ func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vten
 }
 
 // CHECK-LABEL: func.func @groupedConvolution2D(
-// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
-// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
-// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
-// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.collapse_shape
 // CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
 func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
@@ -127,12 +128,14 @@ func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.
 }
 
 // CHECK-LABEL: func.func @transposedGroupedConvolution2D(
-// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
-// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
-// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
-// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
+// CHECK-NEXT: %[[COLLAPSED_TENSOR:.*]] = tensor.collapse_shape
 // CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[COLLAPSED_TENSOR:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
 func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
   %true = torch.constant.bool true
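A note on the renamed captures: FileCheck's %[[NAME:.*]] syntax defines a substitution variable by capturing whatever the regex matches, and a later bare %[[NAME]] must match exactly the same text, so descriptive names such as BROADCASTED_BIAS and CONV_RESULT document the dataflow being checked far better than VAR1..VAR12. A minimal self-contained sketch of define-then-reuse (hypothetical IR, not from this test):

// CHECK-LABEL: func.func @add_then_return(
// CHECK: %[[SUM:.*]] = arith.addf
// CHECK: return %[[SUM]] : f32
func.func @add_then_return(%a: f32, %b: f32) -> f32 {
  %0 = arith.addf %a, %b : f32
  return %0 : f32
}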
