diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index dc8b5d431002..475e0ec407d4 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -661,7 +661,8 @@ class ConvertAtenUnflattenIntOp
                                          "Expected input type having sizes");
     }
     int inputRank = inputTensorType.getSizes().size();
-    int outputRank = outputTensorType.getSizes().size();
+    auto outputSizes = outputTensorType.getSizes();
+    int outputRank = outputSizes.size();
 
     int64_t dimInt;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
@@ -675,23 +676,64 @@ class ConvertAtenUnflattenIntOp
     auto sizesOp = op.getSizes().getDefiningOp<Torch::PrimListConstructOp>();
     int numSizes = sizesOp.getNumOperands();
 
-    SmallVector<ReassociationIndices> reassociations(inputRank);
-    if (inputRank > 0) {
-      for (int i = 0; i < dimInt; ++i)
-        reassociations[i].push_back(i);
-
-      for (int i = 0; i < numSizes; ++i)
-        reassociations[dimInt].push_back(i + dimInt);
-
-      for (int i = dimInt + numSizes; i < outputRank; ++i)
-        reassociations[i - numSizes + 1].push_back(i);
+    int64_t numDynamicReassocDims = 0;
+    for (int64_t i = dimInt; i < dimInt + numSizes; i++) {
+      if (outputSizes[i] == Torch::kUnknownSize)
+        numDynamicReassocDims++;
     }
 
+    SmallVector<Value> reassocSizes;
+    if (!getListConstructElements(op.getSizes(), reassocSizes) &&
+        numDynamicReassocDims > 1)
+      return rewriter.notifyMatchFailure(
+          op, "Must be able to either infer expansion dims, or retrieve them "
+              "from list construct");
+
     auto expandTy = getTypeConverter()->convertType(outputTensorType);
-    auto expand = rewriter
-                      .create<tensor::ExpandShapeOp>(
-                          loc, expandTy, adaptor.getSelf(), reassociations)
-                      .getResult();
+    Value expand;
+    // When there are fewer than two dynamic reassociation dims, this will
+    // lower to tensor.expand_shape. Otherwise, this lowers to tensor.reshape.
+    // TODO: in the numDynamicReassocDims >= 2 case, lower to expand_shape with
+    // explicitly provided outputShape once
+    // https://github.com/iree-org/iree/issues/17760 is resolved.
+    if (numDynamicReassocDims < 2) {
+      SmallVector<ReassociationIndices> reassociations(inputRank);
+      if (inputRank > 0) {
+        for (int i = 0; i < dimInt; ++i)
+          reassociations[i].push_back(i);
+        for (int i = 0; i < numSizes; ++i)
+          reassociations[dimInt].push_back(i + dimInt);
+        for (int i = dimInt + numSizes; i < outputRank; ++i)
+          reassociations[i - numSizes + 1].push_back(i);
+      }
+      expand = rewriter
+                   .create<tensor::ExpandShapeOp>(
+                       loc, expandTy, adaptor.getSelf(), reassociations)
+                   .getResult();
+    } else {
+      reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
+                                            reassocSizes);
+      SmallVector<Value> inputShape =
+          getTensorSizes(rewriter, loc, adaptor.getSelf());
+      inputShape = castIndexVectorToInt64Vector(rewriter, loc, inputShape);
+      SmallVector<Value> outputShape(inputShape.begin(),
+                                     inputShape.begin() + dimInt);
+      if (inputRank > 0) {
+        for (int i = 0; i < numSizes; ++i)
+          outputShape.push_back(reassocSizes[i]);
+        for (int i = dimInt + numSizes; i < outputRank; ++i)
+          outputShape.push_back(inputShape[i - numSizes + 1]);
+      }
+
+      RankedTensorType shapeType = RankedTensorType::get(
+          ArrayRef<int64_t>{outputRank}, rewriter.getIntegerType(64));
+      Value shapeValue =
+          rewriter.create<tensor::FromElementsOp>(loc, shapeType, outputShape);
+      expand = rewriter
+                   .create<tensor::ReshapeOp>(loc, expandTy, adaptor.getSelf(),
+                                              shapeValue)
+                   .getResult();
+    }
     rewriter.replaceOp(op, expand);
     return success();
   }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a0d7616a6a95..efb3bb327cb1 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2195,17 +2195,6 @@ ONNX_XFAIL_SET = {
     # Failure - cast error
     "PermuteNegativeIndexModule_basic",
-    # Failure - expand multiple dynamic dims
-    "EmbeddingModuleF16_basic",
-    "EmbeddingModuleI32_basic",
-    "EmbeddingModuleI64_basic",
-    "IndexTensorHackedTwinModule3dInput_basic",
-    "IndexTensorHackedTwinModule_basic",
-    "IndexTensorModule3dInput_basic",
-    "IndexTensorModule_basic",
-    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
-    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
-    "IndexTensorSelectDimModule_basic",
     # Failure - incorrect numerics
     "AvgPool2dDivisorOverrideModule_basic",
     "BroadcastDynamicDimModule_basic",
diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir
index 3d265a308a0d..2da7c0b74fc2 100644
--- a/test/Conversion/TorchToLinalg/view.mlir
+++ b/test/Conversion/TorchToLinalg/view.mlir
@@ -281,3 +281,30 @@ func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],
   %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,?,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,5,?,6],f32>
   return %1 : !torch.vtensor<[2,5,?,6],f32>
 }
+
+// -----
+
+// This checks a path for unflatten.int with two dynamic reassociation dims.
+// The IR here is generated from the onnx.Gather conversion.
+// CHECK-LABEL: @gather_graph
+// CHECK: %[[fromelt:.*]] = tensor.from_elements
+// CHECK-SAME: tensor<3xi64>
+// CHECK: %[[reshape:.*]] = tensor.reshape
+// CHECK-SAME: (tensor<?x3xf32>, tensor<3xi64>) -> tensor<?x?x3xf32>
+func.func @gather_graph(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?,3],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+  %int-1 = torch.constant.int -1
+  %int5 = torch.constant.int 5
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.lt.Scalar %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],i1>
+  %1 = torch.aten.add.Scalar %arg1, %int5, %int1 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?],si64>
+  %2 = torch.aten.where.self %0, %1, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
+  %3 = torch.aten.size.int %2, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %4 = torch.aten.size.int %2, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %5 = torch.prim.ListConstruct %3, %4 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
+  %7 = torch.aten.view %2, %6 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+  %8 = torch.aten.index_select %arg0, %int0, %7 : !torch.vtensor<[5,3],f32>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,3],f32>
+  %9 = torch.aten.unflatten.int %8, %int0, %5 : !torch.vtensor<[?,3],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,3],f32>
+  return %9 : !torch.vtensor<[?,?,3],f32>
+}
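For contrast with the gather_graph test above, a minimal hypothetical sketch (not part of the patch) of the other branch: per the comment in ConvertAtenUnflattenIntOp, an unflatten.int with at most one dynamic reassociation dim should keep lowering through tensor.expand_shape rather than tensor.reshape. The function name and shapes below are illustrative assumptions, written in the same FileCheck style as test/Conversion/TorchToLinalg/view.mlir:

// CHECK-LABEL: @unflatten_one_dynamic_dim
// CHECK: tensor.expand_shape
// CHECK-NOT: tensor.reshape
func.func @unflatten_one_dynamic_dim(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,6],f32> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %int-1 = torch.constant.int -1
  // Unflatten dim 0 into [-1, 2]: only one of the two new dims is dynamic,
  // so numDynamicReassocDims == 1 and the expand_shape path applies.
  %0 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.unflatten.int %arg0, %int0, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,6],f32>
  return %1 : !torch.vtensor<[?,2,6],f32>
}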