@@ -76,3 +76,75 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
   %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
   return %2 : !torch.vtensor<[1,1024,3000],f32>
 }
+
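+// Transposed convolution with stride 2: the lowering rearranges the weight via
+// a linalg.generic (2x4x3x3 -> 4x2x3x3), zero-inserts and pads the input
+// (1x2x5x7 -> 1x2x13x17), then applies an ordinary conv_2d_nchw_fchw.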
+// CHECK-LABEL: func.func @transposedConv2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: = linalg.generic
+// CHECK-SAME: outs(%[[VAR1:.*]] : tensor<4x2x3x3xf32>) {
+// CHECK: %[[VAR2:.*]] = tensor.extract
+// CHECK-SAME: : tensor<2x4x3x3xf32>
+// CHECK-NEXT: linalg.yield %[[VAR3:.*]] : f32
+// CHECK-NEXT: } -> tensor<4x2x3x3xf32>
+// CHECK: %[[VAR4:.*]] = linalg.broadcast ins(%[[VAR5:.*]] : tensor<4xf32>) outs(%[[VAR6:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR7:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR8:.*]], %[[VAR9:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[VAR10:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[VAR11:.*]] = tensor.cast %[[VAR12:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_4_3_3_torch.float32> : tensor<2x4x3x3xf32>) : !torch.vtensor<[2,4,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
+
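+// Grouped convolution (groups=2): input and weight gain a leading groups
+// dimension, the op lowers to conv_2d_ngchw_gfchw, and the result is
+// collapsed back to NCHW (1x2x2x5x7 -> 1x4x5x7).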
+// CHECK-LABEL: func.func @groupedConvolution2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
+// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
+// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
+func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_4_2_3_3_torch.float32> : tensor<4x2x3x3xf32>) : !torch.vtensor<[4,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int2 : !torch.vtensor<[1,4,5,7],f32>, !torch.vtensor<[4,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,5,7],f32>
+  return %6 : !torch.vtensor<[1,4,5,7],f32>
+}
+
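+// Transposed grouped convolution (groups=2): combines the two lowerings above,
+// feeding the zero-inserted, padded input (1x2x5x7 -> 1x2x1x13x17) to
+// conv_2d_ngchw_gfchw and collapsing the result back to NCHW.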
+// CHECK-LABEL: func.func @transposedGroupedConvolution2D(
+// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: %[[VAR1:.*]] = linalg.broadcast ins(%[[VAR2:.*]] : tensor<4xf32>) outs(%[[VAR3:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[VAR4:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[VAR5:.*]], %[[VAR6:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[VAR7:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
+// CHECK-NEXT: %[[VAR8:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
+func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_2_3_3_torch.float32> : tensor<2x2x3x3xf32>) : !torch.vtensor<[2,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int2 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
+