Commit d231fa3

Merge pull request #554 from Xilinx/bump_to_af8514c9
[AutoBump] Merge with af8514c (Jan 27) (160)
2 parents 71af168 + dcf1ef2 commit d231fa3

11 files changed: +239 -16 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 28 additions & 0 deletions
@@ -13877,6 +13877,34 @@ def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
   }];
 }
 
+def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchTensorType:$a,
+    AnyTorchOptionalListOfTorchIntType:$size,
+    AnyTorchOptionalListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalIntType:$layout
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_AssertTensorMetadataOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 0);
+    }
+    void Aten_AssertTensorMetadataOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 0);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
     AllowsTypeRefinement,
     ReadOnly
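For orientation, here is a hedged Python sketch of the runtime semantics this op encodes; the helper name and error messages are illustrative, not taken from the ATen implementation.

import torch
from typing import List, Optional

def assert_tensor_metadata(a: torch.Tensor,
                           size: Optional[List[int]] = None,
                           stride: Optional[List[int]] = None,
                           dtype: Optional[torch.dtype] = None) -> None:
    # Each check is skipped when its argument is None, mirroring the
    # optional (int[]?, int[]?, int?) operands declared above.
    if size is not None:
        assert list(a.shape) == size, f"size mismatch: {list(a.shape)} vs {size}"
    if stride is not None:
        assert list(a.stride()) == stride, "stride mismatch"
    if dtype is not None:
        assert a.dtype == dtype, f"dtype mismatch: {a.dtype} vs {dtype}"

assert_tensor_metadata(torch.zeros(2, 3), size=[2, 3], dtype=torch.float32)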

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 47 additions & 6 deletions
@@ -2346,7 +2346,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
         unsigned inputRank = inputShape.size();
         // only handle 2D, 3D and 5D pooling cases
-        if (inputRank > 5 or inputRank < 3) {
+        if (inputRank > 5 || inputRank < 3) {
           return failure();
         }
         if (!resultType || !resultType.hasSizes()) {
@@ -2454,7 +2454,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                                              "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
         // only 1D, 2D and 3D LpPool is supported.
-        if (rank > 5 or rank < 3) {
+        if (rank > 5 || rank < 3) {
           return failure();
         }
 
@@ -3697,11 +3697,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType))
           return failure();
 
-        // TODO: Add support for non-zero center_point_box value.
-        if (centerPointBox != 0)
+        if (centerPointBox != 0 && centerPointBox != 1)
           return rewriter.notifyMatchFailure(
-              binder.op, "unimplemented: expected center_point_box "
-                         "attribute value to be 0");
+              binder.op, "expected center_point_box attribute to be 0 or 1");
 
         // TODO: Support multiple batches and classes
         // Squeeze the boxes and scores tensor.
@@ -3727,6 +3725,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                                              "failed to squeeze scores tensor");
         boxes = squeezedBoxes.value();
         scores = squeezedScores.value();
+        if (centerPointBox == 1) {
+          // When center_point_box is 1, the box data is supplied as
+          // [[x_center, y_center, width, height], ...]. Slice it to
+          // [[x_center, y_center], ...] and [[width, height], ...],
+          // calculate the [[x1, y1], ...] and [[x2, y2], ...], and
+          // concatenate to [[x1, y1, x2, y2], ...].
+          auto boxesTensorType =
+              dyn_cast<Torch::ValueTensorType>(boxes.getType());
+          Value const0 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+          Value const1 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(1));
+          Value const2 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(2));
+          Value const4 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(4));
+          Value const2F = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(2.0));
+
+          // extract scaled ranges for regions of interest
+          auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
+          auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
+              sliceShape, boxesTensorType.getDtype());
+          Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, sliceTensorType, boxes, const1, const0, const2, const1);
+          Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
+              loc, sliceTensorType, boxes, const1, const2, const4, const1);
+          Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
+              loc, sizes.getType(), sizes, const2F);
+          Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
+              loc, centers.getType(), centers, halfSizes, const1);
+          Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
+              loc, centers.getType(), centers, halfSizes, const1);
+
+          Type listElemType = boxesTensorType.getWithSizesAndDtype(
+              /*optionalSizes=*/std::nullopt,
+              /*optionalDtype=*/nullptr);
+          Type listType = Torch::ListType::get(listElemType);
+          Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+              loc, listType, SmallVector<Value>{x1y1s, x2y2s});
+          boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
+                                                    tensorList, const1);
+        }
 
         // TODO: Support score_threshold input
         // Filter out the boxes if the score < score_threshold
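For intuition, here is a minimal Python sketch of the coordinate conversion that this slice/div/sub/add/cat sequence builds; the function name is illustrative.

import torch

def center_to_corners(boxes: torch.Tensor) -> torch.Tensor:
    # boxes: [N, 4] as [x_center, y_center, width, height]
    centers = boxes[:, 0:2]            # AtenSliceTensorOp over dims 0..2
    half_sizes = boxes[:, 2:4] / 2.0   # AtenSliceTensorOp + AtenDivScalarOp
    x1y1 = centers - half_sizes        # AtenSubTensorOp
    x2y2 = centers + half_sizes        # AtenAddTensorOp
    return torch.cat([x1y1, x2y2], dim=1)  # AtenCatOp along dim 1

# A 4x2 box centered at (5, 5) becomes corners (3, 4)..(7, 6):
print(center_to_corners(torch.tensor([[5.0, 5.0, 4.0, 2.0]])))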

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 48 additions & 0 deletions
@@ -5462,6 +5462,54 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
   return getA();
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_AssertTensorMetadataOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult Aten_AssertTensorMetadataOp::fold(
+    FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
+  Value input = getA();
+  auto inputType = cast<BaseTensorType>(input.getType());
+  if (!inputType.hasDtype() || !inputType.hasSizes())
+    return failure();
+
+  // TODO: Add checks for stride, device, and layout when we can extract that
+  // information from the torch tensor. For now, we can only get the shape and
+  // dtype info from the tensor, hence adding checks for them.
+
+  // Convert size to a list of integers.
+  SmallVector<int64_t> size;
+  if (!isa<Torch::NoneType>(getSize().getType())) {
+    if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
+      return emitOpError("expected size to be a list of constant ints");
+    }
+    if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
+                      [](const auto &pair) {
+                        return std::get<0>(pair) == std::get<1>(pair);
+                      }))
+      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
+                         "the sizes do not match");
+  }
+
+  // Convert dtype to an integer.
+  int64_t dtype;
+  if (!isa<Torch::NoneType>(getDtype().getType())) {
+    if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
+      return emitOpError("expected dtype to be a constant int");
+    }
+    FailureOr<Type> inputDtype =
+        getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
+    if (failed(inputDtype))
+      return failure();
+    if (inputType.getDtype() != inputDtype)
+      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
+                         "the dtype does not match");
+  }
+
+  getOperation()->erase();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMaxPoolWithIndicesOp
 //===----------------------------------------------------------------------===//
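In Python terms, the folder's decision logic can be sketched as follows; the op is erased only when the statically known type already proves the assertion. The function name and error strings are illustrative, and raising an exception here loosely stands in for the emitOpError diagnostics above.

from typing import List, Optional, Sequence

def fold_assert_tensor_metadata(static_sizes: Sequence[int], static_dtype: str,
                                size: Optional[List[int]],
                                dtype: Optional[str]) -> bool:
    # None means the operand was not supplied, so that check is skipped.
    if size is not None and list(static_sizes) != list(size):
        raise ValueError("sizes do not match")    # emitOpError path above
    if dtype is not None and static_dtype != dtype:
        raise ValueError("dtype does not match")  # emitOpError path above
    return True  # assertion holds statically; the op can be erased

assert fold_assert_tensor_metadata((2, 3), "f32", [2, 3], "f32")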

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 30 additions & 0 deletions
@@ -11950,7 +11950,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
@@ -11962,11 +11972,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 4 additions & 4 deletions
@@ -9780,16 +9780,16 @@ class DecomposeAtenNllLossForwardOp
     auto targetSizes = targetType.getSizes();
     int64_t selfRank = selfSizes.size();
     int64_t targetRank = targetSizes.size();
-    if (selfRank <= 0 or selfRank > 2) {
+    if (selfRank <= 0 || selfRank > 2) {
       return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
     }
     if (targetRank > 1) {
      return rewriter.notifyMatchFailure(op,
                                         "target tensor should be 0D or 1D!");
     }
 
-    if (selfRank != 1 or targetRank != 0) {
-      if (!(selfSizes[0] == kUnknownSize and targetSizes[0] == kUnknownSize) and
+    if (selfRank != 1 || targetRank != 0) {
+      if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) &&
           selfSizes[0] != targetSizes[0]) {
         return rewriter.notifyMatchFailure(
             op,
@@ -9907,7 +9907,7 @@ class DecomposeAtenNllLossForwardOp
                                 zeroTensor);
 
     Value totalWeight;
-    if (reduction == 0 and selfRank > 1) {
+    if (reduction == 0 && selfRank > 1) {
       auto zeroFloat =
           rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
       Value twSize = rewriter.create<PrimListConstructOp>(

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 6 additions & 3 deletions
@@ -2922,9 +2922,10 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
 def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
@@ -2937,14 +2938,16 @@ def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2], error_types={torch.uint8}))
 def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2], error_types={torch.uint8}))
 def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
 # @check_dtype_function(_check_tensors_with_the_same_dtype(
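The `%int0` constant in the generated library above is the ScalarType code for torch.uint8 (0), so the generated MLIR and this Python source implement the same guard. A quick illustration of the behavior, using torch.dtype values as a simplified stand-in for the integer dtype codes the real functions receive:

import torch
from typing import Tuple

def avg_pool1d_dtype(self_rank_dtype: Tuple[int, torch.dtype]) -> torch.dtype:
    self_rank, self_dtype = self_rank_dtype
    assert self_dtype != torch.uint8  # the guard added in this commit
    return self_dtype

print(avg_pool1d_dtype((3, torch.float32)))  # passes through
try:
    avg_pool1d_dtype((3, torch.uint8))
except AssertionError:
    print("uint8 input rejected")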

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 4 additions & 0 deletions
@@ -1019,6 +1019,10 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
+    emit(
+        "aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
+        has_folder=True,
+    )
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")

pytorch-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3f159d635772fa2a8fd352d96b95100d885f8169
+37626ee0e6ff5dc1d38664690bd2ff6c790aab0c

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 # release page, and we use this page as an additional source for the wheels.
 -f https://xilinx.github.io/torch-mlir/package-index/
 --pre
-torch==2.6.0.dev20241216
+torch==2.7.0.dev20250120
