@@ -78,17 +78,18 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
}

// CHECK-LABEL: func.func @transposedConv2D(
- // CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+ // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
// CHECK: = linalg.generic
- // CHECK-SAME: outs(%[[VAR1:.*]] : tensor<4x2x3x3xf32>) {
- // CHECK: %[[VAR2:.*]] = tensor.extract
+ // CHECK-SAME: outs(%[[BROADCASTED_WEIGHTS_INIT:.*]] : tensor<4x2x3x3xf32>) {
+ // CHECK: %[[WEIGHTS:.*]] = tensor.extract
// CHECK-SAME: : tensor<2x4x3x3xf32>
- // CHECK-NEXT: linalg.yield %[[VAR3:.*]] : f32
+ // CHECK-NEXT: linalg.yield %[[BROADCASTED_WEIGHTS:.*]] : f32
// CHECK-NEXT: } -> tensor<4x2x3x3xf32>
- // CHECK: %[[VAR4:.*]] = linalg.broadcast ins(%[[VAR5:.*]] : tensor<4xf32>) outs(%[[VAR6:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
- // CHECK: %[[VAR7:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
- // CHECK-SAME: ins(%[[VAR8:.*]], %[[VAR9:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[VAR10:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
- // CHECK-NEXT: %[[VAR11:.*]] = tensor.cast %[[VAR12:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+ // CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+ // CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ // CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
+ // CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[CONV_RESULT:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+ // CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
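
For reference, the shapes asserted by this hunk correspond to a plain (groups=1) 2-D transposed convolution. The PyTorch-level sketch below is illustrative only: the stride, padding, and output_padding values are assumptions chosen to reproduce the [1,2,5,7] -> [1,4,10,14] shapes in the CHECK lines, since the test's constant operands fall outside this hunk.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: stride/padding/output_padding are assumed values that
# reproduce the shapes the CHECK lines assert; the test's actual constants
# are not shown in this hunk.
x = torch.randn(1, 2, 5, 7)   # input, !torch.vtensor<[1,2,5,7],f32>
w = torch.randn(2, 4, 3, 3)   # (C_in, C_out, kH, kW), the tensor<2x4x3x3xf32> weight
b = torch.randn(4)            # bias, the tensor<4xf32> operand of linalg.broadcast
y = F.conv_transpose2d(x, w, b, stride=2, padding=1, output_padding=1)
print(y.shape)                # torch.Size([1, 4, 10, 14])
```
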
@@ -105,11 +106,11 @@ func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vten
}

// CHECK-LABEL: func.func @groupedConvolution2D(
- // CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
- // CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
- // CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
- // CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
- // CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+ // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
+ // CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
+ // CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ // CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
+ // CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.collapse_shape
// CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
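
Likewise for the grouped case: the sketch below is an illustration under assumed parameters (stride=1, padding=1, groups=2) that reproduces the [1,4,5,7] -> [1,4,5,7] shapes and the group-split 1x2x2x7x9 / 2x2x2x3x3 operands checked above.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: stride/padding/groups are assumed values matching the
# shapes asserted above (padded input 7x9, two groups of two channels each).
x = torch.randn(1, 4, 5, 7)   # input, !torch.vtensor<[1,4,5,7],f32>
w = torch.randn(4, 2, 3, 3)   # (C_out, C_in/groups, kH, kW); corresponds to the grouped tensor<2x2x2x3x3xf32>
b = torch.randn(4)
y = F.conv2d(x, w, b, stride=1, padding=1, groups=2)
print(y.shape)                # torch.Size([1, 4, 5, 7])
```
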
@@ -127,12 +128,14 @@ func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.
}

// CHECK-LABEL: func.func @transposedGroupedConvolution2D(
- // CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
- // CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
- // CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
- // CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
- // CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+ // CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+ // CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+ // CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+ // CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
+ // CHECK-NEXT: %[[COLLAPSED_TENSOR:.*]] = tensor.collapse_shape
// CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
+ // CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[COLLAPSED_TENSOR:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+ // CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
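
And for the transposed grouped case: the sketch below uses assumed parameters (stride=2, padding=1, output_padding=1, groups=2) that reproduce the [1,2,5,7] -> [1,4,10,14] shapes and the 2x2x1x3x3 per-group weight layout checked above; the actual constants of this test are not visible in the hunk.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: all convolution parameters are assumed values that
# reproduce the shapes asserted above; the test's constants are elided here.
x = torch.randn(1, 2, 5, 7)   # input, !torch.vtensor<[1,2,5,7],f32>
w = torch.randn(2, 2, 3, 3)   # (C_in, C_out/groups, kH, kW); two groups of one input channel each
b = torch.randn(4)
y = F.conv_transpose2d(x, w, b, stride=2, padding=1, output_padding=1, groups=2)
print(y.shape)                # torch.Size([1, 4, 10, 14])
```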