Skip to content

Commit 2eafb97

Browse files
authored
Merge pull request data-apis#244 from crusaderky/torch_uint
ENH: More uint types for torch
2 parents 6e897a1 + c787bea commit 2eafb97

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

array_api_compat/torch/_aliases.py

+6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
torch.int32,
3131
torch.int64,
3232
}
33+
try:
34+
# torch >=2.3
35+
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
36+
except AttributeError:
37+
pass
38+
3339

3440
_array_api_dtypes = {
3541
torch.bool,

torch-xfails.txt

-12
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,12 @@ array_api_tests/test_array_object.py::test_getitem
88
array_api_tests/test_array_object.py::test_setitem
99
# Masking doesn't suport 0 dimensions in the mask
1010
array_api_tests/test_array_object.py::test_getitem_masking
11-
# torch doesn't have uint dtypes other than uint8
12-
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)]
13-
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)]
14-
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)]
15-
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)]
16-
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)]
17-
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)]
1811

1912
# Overflow error from large inputs
2013
array_api_tests/test_creation_functions.py::test_arange
2114
# pytorch linspace bug (should be fixed in torch 2.0)
2215
array_api_tests/test_creation_functions.py::test_linspace
2316

24-
# torch doesn't have higher uint dtypes
25-
array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
26-
array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
27-
array_api_tests/test_data_type_functions.py::test_iinfo[uint64]
28-
2917
# We cannot wrap the tensor object
3018
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
3119
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

0 commit comments

Comments
 (0)