diff --git a/externals/llvm-project b/externals/llvm-project index 65e44b4301eb..6d847b1aada5 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 65e44b4301eb1ae6838ad101f35a7d987949e13b +Subproject commit 6d847b1aada50d59c3e29f2e7eff779c0ee8182c diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 523c40a1dddc..97108a38bebb 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1623,7 +1623,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { int64_t shape; }; - // Transpose needs to done if transposeDims are not non-monotonically + // Transpose needs to done if transposedDims are not non-monotonically // increasing. E.g. [0, 1, 2, 3]: No transpose [1, 0, 2, 3]: Transpose dim0 // and dim1 The order need not be sequential, since one or more dims may // have been removed due to broadcasting. @@ -1739,19 +1739,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto transposedLhsType = RankedTensorType::get( makeShapeLLVMCompatible(transposedLhsShape), rhsElemTy); - std::optional transposedLhsDimsConst = - tosa::getConstTensor( - rewriter, op, - /*vec=*/transposedLhsDims, - /*shape=*/{static_cast(transposedLhsDims.size())}); - lhsReshapeInput = rewriter .create( op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedLhsType), - rankBroadcastedLhs, transposedLhsDimsConst.value()) + rankBroadcastedLhs, + rewriter.getDenseI32ArrayAttr(transposedLhsDims)) .getResult(); } @@ -1819,22 +1814,16 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto transposedRhsValue = rankBroadcastedRhs; - if (rhsNeedsTranspose) { - std::optional transposedRhsDimsConst = - tosa::getConstTensor( - rewriter, op, - /*vec=*/transposedRhsDims, - /*shape=*/{static_cast(transposedRhsDims.size())}); - + if (rhsNeedsTranspose) transposedRhsValue = rewriter .create( op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedRhsType), - rankBroadcastedRhs, transposedRhsDimsConst.value()) + rankBroadcastedRhs, + rewriter.getDenseI32ArrayAttr(transposedRhsDims)) .getResult(); - } // reshape matmulRhs = rewriter.create( @@ -1985,13 +1974,6 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { tosa::getTosaConstShape(rewriter, op->getLoc(), reshapedOpShape)); if (opNeedsTranspose) { - - std::optional transposedOpShapeConst = - tosa::getConstTensor( - rewriter, op, - /*vec=*/transposedOpDims, - /*shape=*/{static_cast(transposedOpDims.size())}); - auto transposedOpType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOpShape), outputElemTy); output = rewriter @@ -1999,7 +1981,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedOpType), - reshapedOp.getResult(), transposedOpShapeConst.value()) + reshapedOp.getResult(), + rewriter.getDenseI32ArrayAttr(transposedOpDims)) .getResult(); } else { @@ -2178,19 +2161,13 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]); std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]); - std::optional transposedRhsShapeConst = - tosa::getConstTensor( - rewriter, op, - /*vec=*/transposedRhsDims, - /*shape=*/{static_cast(transposedRhsDims.size())}); - auto transposedRhsType = RankedTensorType::get( makeShapeLLVMCompatible(transposedRhsShape), rhsElemTy); 
rhs = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( transposedRhsType), - rhs, transposedRhsShapeConst.value()); + rhs, rewriter.getDenseI32ArrayAttr(transposedRhsDims)); Value matmulOutput; if (failed( @@ -2350,9 +2327,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); - // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. - // The Torch OFM computation uses 2*pad in each spatial direction, implying - // the same t=b and l=r values for TOSA. + // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D + // padding {height, width}. The Torch OFM computation uses 2*pad in each + // spatial direction, implying the same top=bottom=height and left=right=width + // values for TOSA. SmallVector padding( {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]}); @@ -2369,10 +2347,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. // Perform the necessary transformations. - std::optional nchwToNhwcTransposeConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/{0, 2, 3, 1}, - /*shape=*/{static_cast(4)}); + SmallVector nchwToNhwcDims({0, 2, 3, 1}); SmallVector transposedInputShape( {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); auto transposedInputType = RankedTensorType::get( @@ -2382,7 +2357,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedInputType), input, - nchwToNhwcTransposeConst.value()) + rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); SmallVector transformedWeightShape; @@ -2400,16 +2375,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transformedWeightType), weight, - nchwToNhwcTransposeConst.value()) + rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); outputCDim = transformedWeightShape[0]; } else if (weightShape[1] == 1) { // depthwise convolution: O(I/G)HW-> HWIM) // transpose: O(I/G)HW -> HWO(I/G) - std::optional transposeConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/{2, 3, 0, 1}, - /*shape=*/{static_cast(4)}); + SmallVector transposedDims({2, 3, 0, 1}); SmallVector transposedWeightShape = { weightShape[2], weightShape[3], weightShape[0], weightShape[1]}; auto transposedWeightType = RankedTensorType::get( @@ -2419,7 +2391,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedWeightType), weight, - transposeConst.value()) + rewriter.getDenseI32ArrayAttr(transposedDims)) .getResult(); // reshape: HWO(I/G) -> HWIM @@ -2456,14 +2428,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t inputWDim = inputShape[3]; int64_t weightHDim = weightShape[2]; int64_t weightWDim = weightShape[3]; - outputHDim = (inputHDim + padding[0] + padding[1] - - dilation[0] * (weightHDim - 1) - 1) / - stride[0] + - 1; - outputWDim = (inputWDim + padding[2] + padding[3] - - dilation[1] * (weightWDim - 1) - 1) / - stride[1] + - 1; + + // fullDim = + // inputDim + padBefore + padAfter - dilation * (weightDim - 1) - 1 + // According to TOSA spec: + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d, fullDim values + // must be divisible by stride values. 
+ int64_t fullHDim = inputHDim + padding[0] + padding[1] - + dilation[0] * (weightHDim - 1) - 1; + int64_t remainderHDim = fullHDim % stride[0]; + if (remainderHDim != 0) { + if (remainderHDim > padding[1]) { + SmallVector startHSlice(inputTy.getRank(), 0); + SmallVector sizeHSlice(transposedInputShape); + // TOSA uses NHWC, so we will slice dim 1 for Height value + sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); + transposedInput = rewriter.create( + op->getLoc(), RankedTensorType::get(sizeHSlice, inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startHSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeHSlice)); + fullHDim = fullHDim - padding[1]; + padding[1] = 0; + } else { + fullHDim = fullHDim - padding[1]; + padding[1] = padding[1] - remainderHDim; + fullHDim = fullHDim + padding[1]; + } + } + outputHDim = fullHDim / stride[0] + 1; + + int64_t fullWDim = inputWDim + padding[2] + padding[3] - + dilation[1] * (weightWDim - 1) - 1; + int64_t remainderWDim = fullWDim % stride[1]; + if (remainderWDim != 0) { + if (remainderWDim > padding[3]) { + SmallVector startWSlice(inputTy.getRank(), 0); + SmallVector sizeWSlice( + dyn_cast(transposedInput.getType()).getShape()); + // TOSA uses NHWC, so we will slice dim 2 for Width value + sizeWSlice[2] = inputWDim - (remainderWDim - padding[3]); + transposedInput = rewriter.create( + op->getLoc(), RankedTensorType::get(sizeWSlice, inputElemTy), + transposedInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), startWSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeWSlice)); + fullHDim = fullHDim - padding[3]; + padding[3] = 0; + } else { + fullWDim = fullWDim - padding[3]; + padding[3] = padding[3] - remainderWDim; + fullWDim = fullWDim + padding[3]; + } + } + outputWDim = fullWDim / stride[1] + 1; } else { outputHDim = kUnknownSize; outputWDim = kUnknownSize; @@ -2503,10 +2521,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Unhandled convolution type"); } - std::optional nhwcToNchwTransposeConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/{0, 3, 1, 2}, - /*shape=*/{static_cast(4)}); + SmallVector nhwcToNchwDims({0, 3, 1, 2}); SmallVector transposedOutputShape( {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}); auto transposedOutputType = RankedTensorType::get( @@ -2516,7 +2531,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create( op->getLoc(), getTypeConverter()->convertType(transposedOutputType), - convOpResult, nhwcToNchwTransposeConst.value()) + convOpResult, rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) .getResult(); Value rescaledResult = transposedOutput; @@ -3071,12 +3086,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (auto v : dimListInt) dimListInt32.push_back(v); - auto transposeDimsConst = mlir::tosa::getConstTensor( - rewriter, op.getOperation(), dimListInt32, {selfRank}); - rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - transposeDimsConst.value()); + rewriter.getDenseI32ArrayAttr(dimListInt32)); return success(); } @@ -3841,19 +3853,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "dim0 and dim1 must be less than tensor rank"); - SmallVector transposeDims; + SmallVector transposedDims; for (auto i = 0; i < selfType.getRank(); ++i) - transposeDims.push_back(i); + transposedDims.push_back(i); - transposeDims[dim0] = dim1; - transposeDims[dim1] = dim0; - - auto transposeDimsConst = mlir::tosa::getConstTensor( - 
rewriter, op.getOperation(), transposeDims, {selfType.getRank()}); + transposedDims[dim0] = dim1; + transposedDims[dim1] = dim0; rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - transposeDimsConst.value()); + rewriter.getDenseI32ArrayAttr(transposedDims)); return success(); } @@ -5744,27 +5753,22 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } } - // Apply the transposeDims vector on input to generate a transposed form. + // Apply the transposedDims vector on input to generate a transposed form. Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter, - Value input, ArrayRef transposeDims) const { + Value input, ArrayRef transposedDims) const { auto inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputShape = makeShapeTorchCompatible(inputTy.getShape()); - auto inputRank = inputTy.getRank(); - - std::optional transposeDimsConst = tosa::getConstTensor( - rewriter, op, - /*vec=*/transposeDims, - /*shape=*/{static_cast(inputRank)}); SmallVector transposedInputShape; - for (auto &dim : transposeDims) + for (auto &dim : transposedDims) transposedInputShape.push_back(inputShape[dim]); auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); return rewriter - .create(op->getLoc(), transposedInputType, input, - transposeDimsConst.value()) + .create( + op->getLoc(), transposedInputType, input, + rewriter.getDenseI32ArrayAttr(transposedDims)) .getResult(); } @@ -6588,6 +6592,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Pad value needs to be a scalar constant for conversion to " "TOSA pad operation"); + padTensor = rewriter.create( + op->getLoc(), RankedTensorType::get({1}, selfElemTy), padTensor, + tosa::getTosaConstShape(rewriter, op->getLoc(), {1})); + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, padsList1, padTensor); @@ -6686,10 +6694,7 @@ ConvertAtenOp::matchAndRewrite( auto inputShape = inputTy.getShape(); auto inputElemTy = inputTy.getElementType(); // TOSA works in NHWC. Perform the necessary transformations. 
- std::optional nchwToNhwcTransposeConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/{0, 2, 3, 1}, - /*shape=*/{static_cast(4)}); + SmallVector nchwToNhwcDims({0, 2, 3, 1}); SmallVector transposedInputShape( {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); auto transposedInputTy = RankedTensorType::get( @@ -6698,7 +6703,7 @@ ConvertAtenOp::matchAndRewrite( rewriter .create( op->getLoc(), getTypeConverter()->convertType(transposedInputTy), - input, nchwToNhwcTransposeConst.value()) + input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); auto inputHeight = transposedInputShape[1]; @@ -6836,15 +6841,12 @@ ConvertAtenOp::matchAndRewrite( auto resultType = cast(typeConverter->convertType(op.getType())); - std::optional nhwcToNchwTransposeConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/{0, 3, 1, 2}, - /*shape=*/{static_cast(4)}); + SmallVector nhwcToNchwDims({0, 3, 1, 2}); rewriter .replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), resizeOpResult, - nhwcToNchwTransposeConst.value()) + rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) .getResult(); return success(); @@ -7191,11 +7193,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transposedDims.push_back(static_cast(dim1)); transposedDims.push_back(static_cast(dim2)); - auto transposedDimsConst = tosa::getConstTensor( - rewriter, op, - /*vec=*/transposedDims, - /*shape=*/{static_cast(selfRank)}); - for (auto &dim : transposedDims) transposedInputShape.push_back(selfShape[dim]); @@ -7203,7 +7200,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( makeShapeLLVMCompatible(transposedInputShape), selfElemTy); selfTransposed = rewriter.create( - op->getLoc(), transposedInputType, self, transposedDimsConst.value()); + op->getLoc(), transposedInputType, self, + rewriter.getDenseI32ArrayAttr(transposedDims)); } // Define shape for mask tensor based on rank @@ -7452,14 +7450,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( i++; } - auto permutedDimsConst = - tosa::getConstTensor(rewriter, op, - /*vec=*/permutedDims, - /*shape=*/{static_cast(outRank)}); - - auto result = rewriter.create(op->getLoc(), resultType, - diagonalTensor.value(), - permutedDimsConst.value()); + auto result = rewriter.create( + op->getLoc(), resultType, diagonalTensor.value(), + rewriter.getDenseI32ArrayAttr(permutedDims)); rewriter.replaceOp(op, result.getResult()); @@ -8909,14 +8902,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } permutedDims.push_back(static_cast(dim + 1)); - auto permutedDimsConst = tosa::getConstTensor( - rewriter, op, - /*vec=*/permutedDims, - /*shape=*/{static_cast(selfRank + 1)}) - .value(); - auto result = rewriter.create( - op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst); + op->getLoc(), resultType, reshapeOp.getResult(), + rewriter.getDenseI32ArrayAttr(permutedDims)); rewriter.replaceOp(op, {result.getResult()}); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index dee098855744..bf5d665af097 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -873,7 +873,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } - return convertReduceOpCommon( + return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, keep_dims, output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 
57e5f12ac94b..74fa43e49d6d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3391,8 +3391,6 @@ "ViewDtypeStaticModule_basic", "Unfold_Module_Rank_Zero_Size_Zero_basic", "ArangeZeroElementOutputModule_basic", - "NumpyTRank0Module_basic", - "Permute0RankModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 635e4e7affac..6b5ee5fb533b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -794,16 +794,15 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: ! // ----- // CHECK-LABEL: func.func @torch.aten.permute$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> +// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<3x4x2xf32>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,2,4],f32> // CHECK: } func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> { %int1 = torch.constant.int 1 @@ -955,7 +954,7 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // ----- // CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 7 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 @@ -965,14 +964,12 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : 
(tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> -// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> -// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> -// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x512x7x7xf32>) -> tensor<1x7x7x512xf32> +// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x1x1x512xf32>) -> tensor<1x512x1x1xf32> +// CHECK: %[[VAL_13:.*]] = tensor.cast %[[VAL_12]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,512,1,1],f32> // CHECK: } func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> { %int7 = torch.constant.int 7 @@ -1480,23 +1477,21 @@ func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.bilinear( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = torch.constant.bool false // CHECK: %[[VAL_4:.*]] = torch.constant.str "bilinear" // CHECK: %[[VAL_5:.*]] = torch.constant.float 2.000000e+00 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> -// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_12:.*]] = tosa.resize %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] {mode = "BILINEAR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) 
-> tensor<1x16x270x480xf32> -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> -// CHECK: return %[[VAL_15]] : !torch.vtensor<[1,16,270,480],f32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x16x135x240xf32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "BILINEAR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32> // CHECK: } func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { %none = torch.constant.none @@ -1511,23 +1506,21 @@ func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch. // ----- // CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.nearest( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = torch.constant.bool false // CHECK: %[[VAL_4:.*]] = torch.constant.str "nearest" // CHECK: %[[VAL_5:.*]] = torch.constant.float 2.000000e+00 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> -// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_12:.*]] = tosa.resize %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] {mode = "NEAREST_NEIGHBOR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) -> tensor<1x16x270x480xf32> -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> -// CHECK: return %[[VAL_15]] : 
!torch.vtensor<[1,16,270,480],f32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x16x135x240xf32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "NEAREST_NEIGHBOR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32> // CHECK: } func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { %none = torch.constant.none @@ -1623,12 +1616,12 @@ func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vt // ----- // CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = torch.constant.bool true // CHECK: %[[VAL_4:.*]] = torch.constant.none -// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reduce_product %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32> // CHECK: } @@ -2085,24 +2078,23 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?, // ----- // CHECK-LABEL: func.func @torch.aten.diagonal$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 // CHECK: %[[VAL_4:.*]] = torch.constant.int -2 -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : (tensor<5x6x4x3xi32>, 
tensor<1x1x4x3xi32>, tensor<1xi8>) -> tensor<5x6x4x3xi32> -// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[5, 6, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : (tensor<5x6x4x3xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x6x2x3xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> -// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[5, 6, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<5x6x2x1xi32>, !tosa.shape<3>) -> tensor<5x6x2xi32> -// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> -// CHECK: return %[[VAL_16]] : !torch.vtensor<[5,6,2],si32> +// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<3x4x5x6xi32>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>, tensor<1xi8>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[5, 6, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : (tensor<5x6x4x3xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x6x2x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[5, 6, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_13]] : (tensor<5x6x2x1xi32>, !tosa.shape<3>) -> tensor<5x6x2xi32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,6,2],si32> // CHECK: } func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> { %dim1 = torch.constant.int 1 @@ -2291,13 +2283,11 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc // CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[C6]], %[[C6]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[L3:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[PERMS_IN:.*]] = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]], %[[PERMS_IN]] : (tensor<1x64x56xf32>, tensor<3xi32>) -> tensor<64x56x1xf32> +// CHECK: %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]] {perms = array} : (tensor<1x64x56xf32>) -> tensor<64x56x1xf32> // CHECK: %[[CONST_SHAPE_IN:.*]] = tosa.const_shape {value = dense<[1, 64, 56, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: 
%[[RESHAPE_IN:.*]] = tosa.reshape %[[TRANSPOSE_IN]], %[[CONST_SHAPE_IN]] : (tensor<64x56x1xf32>, !tosa.shape<4>) -> tensor<1x64x56x1xf32> // CHECK: %[[POOL:.*]] = tosa.avg_pool2d %[[RESHAPE_IN]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x64x56x1xf32>) -> tensor<1x59x51x1xf32> -// CHECK: %[[PERMS_OUT:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]], %[[PERMS_OUT]] : (tensor<1x59x51x1xf32>, tensor<4xi32>) -> tensor<1x1x59x51xf32> +// CHECK: %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]] {perms = array} : (tensor<1x59x51x1xf32>) -> tensor<1x1x59x51xf32> // CHECK: %[[CONST_SHAPE_OUT:.*]] = tosa.const_shape {value = dense<[1, 59, 51]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[TRANSPOSE_OUT]], %[[CONST_SHAPE_OUT]] : (tensor<1x1x59x51xf32>, !tosa.shape<3>) -> tensor<1x59x51xf32> // CHECK: %[[CAST:.*]] = tensor.cast %[[RESHAPE_OUT]] : tensor<1x59x51xf32> to tensor<1x59x51xf32> @@ -2432,7 +2422,7 @@ func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg // ----- // CHECK-LABEL: func.func @torch.aten.diag_embed$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int -2 @@ -2464,10 +2454,9 @@ func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg // CHECK: %[[VAL_29:.*]] = tosa.scatter %[[VAL_18]], %[[VAL_28]], %[[VAL_16]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> // CHECK: %[[VAL_30:.*]] = tosa.const_shape {value = dense<[2, 3, 4, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_31:.*]] = tosa.reshape %[[VAL_29]], %[[VAL_30]] : (tensor<1x96x1xf32>, !tosa.shape<4>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_32:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_33:.*]] = tosa.transpose %[[VAL_31]], %[[VAL_32]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_34:.*]] = torch_c.from_builtin_tensor %[[VAL_33]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> -// CHECK: return %[[VAL_34]] : !torch.vtensor<[2,3,4,4],f32> +// CHECK: %[[VAL_32:.*]] = tosa.transpose %[[VAL_31]] {perms = array} : (tensor<2x3x4x4xf32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> +// CHECK: return %[[VAL_33]] : !torch.vtensor<[2,3,4,4],f32> // CHECK: } func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { %int0 = torch.constant.int 0 @@ -2640,7 +2629,7 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor // ----- // CHECK-LABEL: func.func @torch.aten.max_pool1d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> 
tensor<1x64x112xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.bool false // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 @@ -2652,16 +2641,14 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 64, 112, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_10]] : (tensor<1x64x112xf32>, !tosa.shape<4>) -> tensor<1x64x112x1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> -// CHECK: %[[VAL_14:.*]] = tosa.max_pool2d %[[VAL_13]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_14]], %[[VAL_15]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[1, 64, 56]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<1x64x56x1xf32>, !tosa.shape<3>) -> tensor<1x64x56xf32> -// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,64,56],f32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x64x112x1xf32>) -> tensor<1x112x1x64xf32> +// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array} : (tensor<1x56x1x64xf32>) -> tensor<1x64x56x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 64, 56]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<1x64x56x1xf32>, !tosa.shape<3>) -> tensor<1x64x56xf32> +// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32> // CHECK: } func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { %false = torch.constant.bool false @@ -2679,7 +2666,7 @@ func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> // ----- // CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 @@ -2689,16 +2676,14 @@ func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> // 
CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_8]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> -// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32> -// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,512,10],f32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_9]] {perms = array} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_13]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32> // CHECK: } func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { %int1 = torch.constant.int 1 @@ -3407,7 +3392,7 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s // ----- // CHECK-LABEL: func.func @torch.aten.unfold$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 @@ -3433,10 +3418,9 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s // CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<1x24x1xf32>, !tosa.shape<2>) -> tensor<6x4xf32> // CHECK: %[[VAL_24:.*]] = tosa.const_shape {value = dense<[3, 2, 4]> : tensor<3xindex>} : () -> 
!tosa.shape<3> // CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_23]], %[[VAL_24]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<3x2x4xf32> -// CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_27:.*]] = tosa.transpose %[[VAL_25]], %[[VAL_26]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> -// CHECK: %[[VAL_28:.*]] = torch_c.from_builtin_tensor %[[VAL_27]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> -// CHECK: return %[[VAL_28]] : !torch.vtensor<[3,4,2],f32> +// CHECK: %[[VAL_26:.*]] = tosa.transpose %[[VAL_25]] {perms = array} : (tensor<3x2x4xf32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_27]] : !torch.vtensor<[3,4,2],f32> // CHECK: } func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { %int0 = torch.constant.int 0 @@ -3504,7 +3488,7 @@ func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte // ----- // CHECK-LABEL: func.func @torch.aten.constant_pad_nd$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,20,20,4,4],f32> -> tensor<1x1x20x20x4x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.float 0xFFF0000000000000 // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 @@ -3512,9 +3496,11 @@ func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xindex>} : () -> !tosa.shape<12> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, !tosa.shape<12>, tensor) -> tensor<1x1x20x20x4x5xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor, !tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_9]] : (tensor<1x1x20x20x4x4xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<1x1x20x20x4x5xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[1,1,20,20,4,5],f32> // CHECK: } func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { %float-Inf = torch.constant.float 0xFFF0000000000000 @@ -3528,7 +3514,7 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4 // ----- // CHECK-LABEL: func.func @torch.aten.convolution$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { +// CHECK-SAME: 
%[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.bool false // CHECK: %[[VAL_3:.*]] = torch.constant.int 3 @@ -3540,17 +3526,15 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> -// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> -// CHECK: %[[VAL_20:.*]] = tensor.cast %[[VAL_19]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32> // CHECK: } func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { %false = torch.constant.bool false @@ -3569,7 +3553,7 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) // ----- // CHECK-LABEL: func.func 
@torch.aten.convolution$depthwise(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> {
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.bool false
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 4
@@ -3582,20 +3566,17 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>)
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list
 // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_13]] : (tensor<5x4x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x4xf32>
-// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32>
-// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_21:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_18]], %[[VAL_12]], %[[VAL_19]], %[[VAL_20]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
-// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32>
-// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32>
-// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32>
-// CHECK: return %[[VAL_25]] : !torch.vtensor<[5,4,5,10],f32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_13]], %[[VAL_16]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_19]] {perms = array} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32>
 // CHECK: }
 func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> {
   %false = torch.constant.bool false
@@ -3613,3 +3594,134 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3
 }
 // -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$zero_pad_with_sliced_input(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,56,56],f32> -> tensor<1x64x56x56xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense_resource : tensor<128x64x1x1xf32>}> : () -> tensor<128x64x1x1xf32>
+// CHECK: %[[VAL_6:.*]] = torch.constant.none
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32>
+// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 55, 56, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_13]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_14]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32>
+// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_23]] {perms = array} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32>
+// CHECK: %[[VAL_25:.*]] = tensor.cast %[[VAL_24]] : tensor<1x128x28x28xf32> to tensor<1x128x28x28xf32>
+// CHECK: %[[VAL_26:.*]] = torch_c.from_builtin_tensor %[[VAL_25]] : tensor<1x128x28x28xf32> -> !torch.vtensor<[1,128,28,28],f32>
+// CHECK: return %[[VAL_26]] : !torch.vtensor<[1,128,28,28],f32>
+// CHECK: }
+func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtensor<[1,64,56,56],f32>) -> !torch.vtensor<[1,128,28,28],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.vtensor.literal(dense_resource : tensor<128x64x1x1xf32>) : !torch.vtensor<[128,64,1,1],f32>
+  %none = torch.constant.none
+  %int2 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
+  %4 = torch.prim.ListConstruct : () -> !torch.list
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[1,64,56,56],f32>, !torch.vtensor<[128,64,1,1],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,128,28,28],f32>
+  return %5 : !torch.vtensor<[1,128,28,28],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,32,112,112],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,3,224,224],f32> -> tensor<1x3x224x224xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32>
+// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<1x32x112x112xf32> to tensor<1x32x112x112xf32>
+// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x32x112x112xf32> -> !torch.vtensor<[1,32,112,112],f32>
+// CHECK: return %[[VAL_19]] : !torch.vtensor<[1,32,112,112],f32>
+// CHECK: }
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input(%arg0: !torch.vtensor<[1,3,224,224],f32>) -> !torch.vtensor<[1,32,112,112],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int2 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
+  %4 = torch.prim.ListConstruct : () -> !torch.list
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,32,112,112],f32>
+  return %5 : !torch.vtensor<[1,32,112,112],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,3,225,225],f32>) -> !torch.vtensor<[1,32,75,75],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,3,225,225],f32> -> tensor<1x3x225x225xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch.constant.none
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32>
+// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32>
+// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32>
+// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<1x32x75x75xf32> to tensor<1x32x75x75xf32>
+// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<1x32x75x75xf32> -> !torch.vtensor<[1,32,75,75],f32>
+// CHECK: return %[[VAL_25]] : !torch.vtensor<[1,32,75,75],f32>
+// CHECK: }
+func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input(%arg0: !torch.vtensor<[1,3,225,225],f32>) -> !torch.vtensor<[1,32,75,75],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense_resource : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
+  %none = torch.constant.none
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
+  %4 = torch.prim.ListConstruct : () -> !torch.list
+  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[1,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,32,75,75],f32>
+  return %5 : !torch.vtensor<[1,32,75,75],f32>
+}
+
+// -----