Merge pull request #554 from Xilinx/bump_to_af8514c9
[AutoBump] Merge with af8514c (Jan 27) (160)
mgehre-amd authored Feb 14, 2025
2 parents 71af168 + dcf1ef2 commit d231fa3
Showing 11 changed files with 239 additions and 16 deletions.
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13877,6 +13877,34 @@ def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
}];
}

def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()`";
let arguments = (ins
AnyTorchTensorType:$a,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalIntType:$layout
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AssertTensorMetadataOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 0);
}
void Aten_AssertTensorMetadataOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 0);
}
}];
let hasFolder = 1;
}

def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
AllowsTypeRefinement,
ReadOnly
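For context, a minimal sketch of the upstream PyTorch op that the new Torch_Aten_AssertTensorMetadataOp definition above wraps. torch.export inserts these assertions, and the exact calling convention may differ across PyTorch versions, so treat this as a hedged illustration rather than the canonical API:

    import torch

    t = torch.ones(2, 3, dtype=torch.float32)
    # Asserts at runtime that `t` has the given sizes and dtype; the op
    # produces no result, matching the empty `results` list above.
    # Illustrative call only; size/stride are passed as None here.
    torch.ops.aten._assert_tensor_metadata(t, None, None, torch.float32)

The `let hasFolder = 1;` line corresponds to the fold hook added in TorchOps.cpp below, which erases the assertion once its checks are statically provable.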
53 changes: 47 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2346,7 +2346,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
unsigned inputRank = inputShape.size();
// only handle 2D, 3D and 5D pooling cases
if (inputRank > 5 or inputRank < 3) {
if (inputRank > 5 || inputRank < 3) {
return failure();
}
if (!resultType || !resultType.hasSizes()) {
@@ -2454,7 +2454,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
// only 1D, 2D and 3D LpPool is supported.
if (rank > 5 or rank < 3) {
if (rank > 5 || rank < 3) {
return failure();
}

@@ -3697,11 +3697,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();

// TODO: Add support for non-zero center_point_box value.
if (centerPointBox != 0)
if (centerPointBox != 0 && centerPointBox != 1)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");
binder.op, "expected center_point_box attribute to be 0 or 1");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
@@ -3727,6 +3725,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"failed to squeeze scores tensor");
boxes = squeezedBoxes.value();
scores = squeezedScores.value();
if (centerPointBox == 1) {
// When center_point_box is 1, the box data is supplied as
// [[x_center, y_center, width, height], ...]. Slice it to
// [[x_center, y_center], ...] and [[width, height], ...],
            // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
            // to [[x1, y1, x2, y2], ...]
auto boxesTensorType =
dyn_cast<Torch::ValueTensorType>(boxes.getType());
Value const0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value const1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value const2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value const4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));
Value const2F = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));

            // extract the box centers and sizes
auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
sliceShape, boxesTensorType.getDtype());
Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
loc, sliceTensorType, boxes, const1, const0, const2, const1);
Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
loc, sliceTensorType, boxes, const1, const2, const4, const1);
Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
loc, sizes.getType(), sizes, const2F);
Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
loc, centers.getType(), centers, halfSizes, const1);
Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
loc, centers.getType(), centers, halfSizes, const1);

Type listElemType = boxesTensorType.getWithSizesAndDtype(
/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, SmallVector<Value>{x1y1s, x2y2s});
boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
tensorList, const1);
}

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
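The slice/div/sub/add/concat sequence added above is the standard center-to-corner box conversion. A hedged Python equivalent of what the emitted Torch ops compute, assuming the [num_boxes, 4] layout that holds after squeezing:

    import torch

    def center_to_corner(boxes: torch.Tensor) -> torch.Tensor:
        """boxes: [N, 4] rows of [x_center, y_center, width, height]."""
        centers = boxes[:, 0:2]            # AtenSliceTensorOp: dims 0..2 along dim 1
        half_sizes = boxes[:, 2:4] / 2.0   # AtenSliceTensorOp + AtenDivScalarOp
        x1y1 = centers - half_sizes        # AtenSubTensorOp (alpha = 1)
        x2y2 = centers + half_sizes        # AtenAddTensorOp (alpha = 1)
        return torch.cat([x1y1, x2y2], dim=1)  # AtenCatOp -> [x1, y1, x2, y2]

    print(center_to_corner(torch.tensor([[5., 5., 4., 2.]])))  # tensor([[3., 4., 7., 6.]])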
48 changes: 48 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5462,6 +5462,54 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
return getA();
}

//===----------------------------------------------------------------------===//
// Aten_AssertTensorMetadataOp
//===----------------------------------------------------------------------===//

LogicalResult Aten_AssertTensorMetadataOp::fold(
FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
Value input = getA();
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.hasSizes())
return failure();

  // TODO: Add checks for stride, device, and layout when we can extract that
  // information from the torch tensor. For now we can only recover the shape
  // and dtype from the tensor type, so only those are checked.

// convert size to a list of integers.
SmallVector<int64_t> size;
if (!isa<Torch::NoneType>(getSize().getType())) {
if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
      return emitOpError("expected size to be a list of constant ints");
}
if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
[](const auto &pair) {
return std::get<0>(pair) == std::get<1>(pair);
}))
return emitOpError("Failed to fold the _assert_tensor_metadata op since "
"the sizes do not match");
}

// convert dtype to an integer.
int64_t dtype;
if (!isa<Torch::NoneType>(getDtype().getType())) {
if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
return emitOpError("expected dtype to be a constant int");
}
FailureOr<Type> inputDtype =
getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
if (failed(inputDtype))
return failure();
if (inputType.getDtype() != inputDtype)
return emitOpError("Failed to fold the _assert_tensor_metadata op since "
"the dtype does not match");
}

getOperation()->erase();
return success();
}

//===----------------------------------------------------------------------===//
// AtenMaxPoolWithIndicesOp
//===----------------------------------------------------------------------===//
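In plain terms, the new fold hook erases the assertion once its size and dtype checks are statically provable against the tensor type. A hedged Python restatement of that logic (not an actual API; names are illustrative):

    def try_fold_assert_tensor_metadata(tensor_sizes, tensor_dtype,
                                        asserted_sizes, asserted_dtype):
        """Returns True when the assert op can be erased; raises on a provable
        mismatch. Mirrors the fold above: size and dtype only; stride, device,
        and layout checks are still TODO."""
        if asserted_sizes is not None and list(tensor_sizes) != list(asserted_sizes):
            raise ValueError("sizes do not match")
        if asserted_dtype is not None and tensor_dtype != asserted_dtype:
            raise ValueError("dtype does not match")
        return True  # caller erases the op, as getOperation()->erase() does above

    assert try_fold_assert_tensor_metadata((2, 3), "f32", [2, 3], "f32")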
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -11950,7 +11950,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
@@ -11962,11 +11972,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
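The new guard in each generated dtype function compares the dtype integer against %int0 because torch.uint8 (ScalarType Byte) has the numeric value 0 in the upstream enum; that mapping is an assumption worth double-checking against torch_upstream::ScalarType. A sketch of the control flow that the torch.prim.If / torch.prim.RaiseException pair encodes:

    def avg_pool_dtype(self_rank_dtype):
        self_rank, self_dtype = self_rank_dtype  # torch.prim.TupleUnpack
        if self_dtype != 0:                      # torch.aten.ne.int %0#1, %int0
            pass                                 # torch.prim.If.yield
        else:
            raise AssertionError("")             # torch.prim.RaiseException %str
        return self_dtype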
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9780,16 +9780,16 @@ class DecomposeAtenNllLossForwardOp
auto targetSizes = targetType.getSizes();
int64_t selfRank = selfSizes.size();
int64_t targetRank = targetSizes.size();
if (selfRank <= 0 or selfRank > 2) {
if (selfRank <= 0 || selfRank > 2) {
return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
}
if (targetRank > 1) {
      return rewriter.notifyMatchFailure(op,
                                         "target tensor should be 0D or 1D!");
}

if (selfRank != 1 or targetRank != 0) {
if (!(selfSizes[0] == kUnknownSize and targetSizes[0] == kUnknownSize) and
if (selfRank != 1 || targetRank != 0) {
if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) &&
selfSizes[0] != targetSizes[0]) {
return rewriter.notifyMatchFailure(
op,
@@ -9907,7 +9907,7 @@ class DecomposeAtenNllLossForwardOp
zeroTensor);

Value totalWeight;
if (reduction == 0 and selfRank > 1) {
if (reduction == 0 && selfRank > 1) {
auto zeroFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value twSize = rewriter.create<PrimListConstructOp>(
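For reference, a hedged eager-mode example of the shape contract the checks above enforce: the input must be 1D or 2D, the target 0D or 1D, and the batch dimensions must agree; reduction == 0 corresponds to reduction="none":

    import torch
    import torch.nn.functional as F

    log_probs = torch.log_softmax(torch.randn(4, 10), dim=1)  # 2D input: [batch, classes]
    target = torch.tensor([1, 0, 9, 3])                        # 1D target: [batch]
    per_sample = F.nll_loss(log_probs, target, reduction="none")
    print(per_sample.shape)  # torch.Size([4]); reduction == 0 keeps per-sample losses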
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2922,9 +2922,10 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
Expand All @@ -2937,14 +2938,16 @@ def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2], error_types={torch.uint8}))
def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2], error_types={torch.uint8}))
def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype

# @check_dtype_function(_check_tensors_with_the_same_dtype(
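The error_types={torch.uint8} argument added to the decorators tells the dtype-function test harness to expect an assertion failure for uint8 inputs, matching the new `assert self_dtype != torch.uint8` lines; these Python functions are what get compiled into the AbstractInterpLibrary IR shown earlier. A hedged eager-mode illustration of why uint8 is rejected, assuming (as in recent PyTorch releases) that the average-pool kernels are not implemented for that dtype:

    import torch
    import torch.nn.functional as F

    x = torch.randint(0, 255, (2, 3, 7), dtype=torch.uint8)
    try:
        F.avg_pool1d(x, kernel_size=2)
    except RuntimeError as err:
        print("avg_pool1d rejects uint8:", err)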
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1019,6 +1019,10 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
emit(
"aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
has_folder=True,
)
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
3f159d635772fa2a8fd352d96b95100d885f8169
37626ee0e6ff5dc1d38664690bd2ff6c790aab0c
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
Expand Up @@ -4,4 +4,4 @@
# release page, and we use this page as an additional source for the wheels.
-f https://xilinx.github.io/torch-mlir/package-index/
--pre
torch==2.6.0.dev20241216
torch==2.7.0.dev20250120