Commit 4f0b79c
build: manually update PyTorch version (llvm#3977)
This commit updates the PyTorch and TorchVision pins to the nightly release of 2025-01-20. It also adds the aten::_assert_tensor_metadata op, together with a folder for it. Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent: 2564d7a

File tree: 8 files changed, +119 −6 lines


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 28 additions & 0 deletions
@@ -13826,6 +13826,34 @@ def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
   }];
 }
 
+def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()`";
+  let arguments = (ins
+    AnyTorchTensorType:$a,
+    AnyTorchOptionalListOfTorchIntType:$size,
+    AnyTorchOptionalListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalIntType:$layout
+  );
+  let results = (outs
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_AssertTensorMetadataOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 0);
+    }
+    void Aten_AssertTensorMetadataOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 0);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
     AllowsTypeRefinement,
     ReadOnly
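For orientation, here is a hedged sketch of how the new op can look in Torch-dialect IR once imported. The value names, the !torch.vtensor<[2,3],f32> type, and the surrounding constants are illustrative, not taken from this commit; the default assembly format simply lists the six operands and their types, with no results.

// Illustrative IR only. %t is assumed to be defined earlier as a
// !torch.vtensor<[2,3],f32>; the ScalarType code 6 denotes torch.float32.
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%size = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%int6 = torch.constant.int 6
torch.aten._assert_tensor_metadata %t, %size, %none, %int6, %none, %none : !torch.vtensor<[2,3],f32>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none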

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 48 additions & 0 deletions
@@ -5378,6 +5378,54 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
   return getA();
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_AssertTensorMetadataOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult Aten_AssertTensorMetadataOp::fold(
+    FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
+  Value input = getA();
+  auto inputType = cast<BaseTensorType>(input.getType());
+  if (!inputType.hasDtype() || !inputType.hasSizes())
+    return failure();
+
+  // TODO: Add checks for stride, device, and layout when we can extract that
+  // information from the torch tensor. For now, we can only get the shape and
+  // dtype info from the tensor, hence we only check those.
+
+  // Convert the size operand to a list of integers.
+  SmallVector<int64_t> size;
+  if (!isa<Torch::NoneType>(getSize().getType())) {
+    if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
+      return emitOpError("expected size to be a list of constant ints");
+    }
+    if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
+                      [](const auto &pair) {
+                        return std::get<0>(pair) == std::get<1>(pair);
+                      }))
+      return emitOpError("failed to fold the _assert_tensor_metadata op since "
+                         "the sizes do not match");
+  }
+
+  // Convert the dtype operand to an integer.
+  int64_t dtype;
+  if (!isa<Torch::NoneType>(getDtype().getType())) {
+    if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
+      return emitOpError("expected dtype to be a constant int");
+    }
+    FailureOr<Type> inputDtype =
+        getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
+    if (failed(inputDtype))
+      return failure();
+    if (inputType.getDtype() != inputDtype)
+      return emitOpError("failed to fold the _assert_tensor_metadata op since "
+                         "the dtype does not match");
+  }
+
+  getOperation()->erase();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMaxPoolWithIndicesOp
 //===----------------------------------------------------------------------===//
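The net effect: when the asserted size and dtype are compile-time constants that agree with the input tensor's static type, the folder removes the op; since it has no results, it is erased outright via getOperation()->erase(). A hedged before/after sketch, with illustrative names and types:

// Before folding: the asserted dtype (code 6, torch.float32) matches the
// static type of %t, a !torch.vtensor<[2,3],f32> defined earlier.
%none = torch.constant.none
%int6 = torch.constant.int 6
torch.aten._assert_tensor_metadata %t, %none, %none, %int6, %none, %none : !torch.vtensor<[2,3],f32>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none

// After folding: the assertion is statically satisfied, so the op disappears
// and its constant operands become dead.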

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 30 additions & 0 deletions
@@ -11916,7 +11916,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
@@ -11928,11 +11938,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 6 additions & 3 deletions
@@ -2901,9 +2901,10 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_factor: int) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
 def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
@@ -2916,14 +2917,16 @@ def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2], error_types={torch.uint8}))
 def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2]))
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2], error_types={torch.uint8}))
 def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
     self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.uint8
     return self_dtype
 
 # @check_dtype_function(_check_tensors_with_the_same_dtype(
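The `error_types={torch.uint8}` argument declares to the dtype-check harness that invoking avg_pool with a uint8 tensor is expected to raise, matching the new `assert` in each function body. A minimal, self-contained Python sketch of that contract (hypothetical harness logic, not torch-mlir's actual test code):

import torch

# Stand-in for the dtype functions above: uint8 inputs must be rejected,
# every other dtype is passed through unchanged.
def avg_pool1d_dtype(self_dtype: torch.dtype) -> torch.dtype:
    assert self_dtype != torch.uint8
    return self_dtype

for dtype in (torch.float32, torch.int64, torch.uint8):
    try:
        print(f"{dtype}: ok -> {avg_pool1d_dtype(dtype)}")
    except AssertionError:
        print(f"{dtype}: raises, as declared via error_types")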

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 4 additions & 0 deletions
@@ -1010,6 +1010,10 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
+    emit(
+        "aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
+        has_folder=True,
+    )
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")

pytorch-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3f159d635772fa2a8fd352d96b95100d885f8169
+37626ee0e6ff5dc1d38664690bd2ff6c790aab0c

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.6.0.dev20241216
+torch==2.7.0.dev20250120

torchvision-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.22.0.dev20241216
+torchvision==0.22.0.dev20250120
