Skip to content

Commit be4fa68

Browse files
committed
Handle torch versions that do not have uint dtypes
1 parent d3c4b3c commit be4fa68

File tree

1 file changed

+87
-87
lines changed

1 file changed

+87
-87
lines changed

array_api_compat/torch/_info.py

+87-87
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,7 @@
77
more details.
88
99
"""
10-
from torch import (
11-
asarray,
12-
get_default_dtype,
13-
device,
14-
empty,
15-
bool,
16-
int8,
17-
int16,
18-
int32,
19-
int64,
20-
uint8,
21-
uint16,
22-
uint32,
23-
uint64,
24-
float32,
25-
float64,
26-
complex64,
27-
complex128,
28-
)
10+
import torch
2911

3012
from functools import cache
3113

@@ -130,7 +112,7 @@ def default_device(self):
130112
'cpu'
131113
132114
"""
133-
return device("cpu")
115+
return torch.device("cpu")
134116

135117
def default_dtypes(self, *, device=None):
136118
"""
@@ -165,80 +147,32 @@ def default_dtypes(self, *, device=None):
165147
'indexing': torch.int64}
166148
167149
"""
168-
default_floating = get_default_dtype()
169-
default_complex = complex64 if default_floating == float32 else complex128
170-
default_integral = asarray(0, device=device).dtype
150+
default_floating = torch.get_default_dtype()
151+
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128
152+
default_integral = torch.asarray(0, device=device).dtype
171153
return {
172154
"real floating": default_floating,
173155
"complex floating": default_complex,
174156
"integral": default_integral,
175157
"indexing": default_integral,
176158
}
177159

178-
@cache
179-
def dtypes(self, *, device=None, kind=None):
180-
"""
181-
The array API data types supported by PyTorch.
182-
183-
Note that this function only returns data types that are defined by
184-
the array API.
185-
186-
Parameters
187-
----------
188-
device : str, optional
189-
The device to get the data types for.
190-
kind : str or tuple of str, optional
191-
The kind of data types to return. If ``None``, all data types are
192-
returned. If a string, only data types of that kind are returned.
193-
If a tuple, a dictionary containing the union of the given kinds
194-
is returned. The following kinds are supported:
195-
196-
- ``'bool'``: boolean data types (i.e., ``bool``).
197-
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
198-
``int16``, ``int32``, ``int64``).
199-
- ``'unsigned integer'``: unsigned integer data types (i.e.,
200-
``uint8``, ``uint16``, ``uint32``, ``uint64``).
201-
- ``'integral'``: integer data types. Shorthand for ``('signed
202-
integer', 'unsigned integer')``.
203-
- ``'real floating'``: real-valued floating-point data types
204-
(i.e., ``float32``, ``float64``).
205-
- ``'complex floating'``: complex floating-point data types (i.e.,
206-
``complex64``, ``complex128``).
207-
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
208-
'real floating', 'complex floating')``.
209-
210-
Returns
211-
-------
212-
dtypes : dict
213-
A dictionary mapping the names of data types to the corresponding
214-
PyTorch data types.
215-
216-
See Also
217-
--------
218-
__array_namespace_info__.capabilities,
219-
__array_namespace_info__.default_device,
220-
__array_namespace_info__.default_dtypes,
221-
__array_namespace_info__.devices
222-
223-
Examples
224-
--------
225-
>>> info = np.__array_namespace_info__()
226-
>>> info.dtypes(kind='signed integer')
227-
{'int8': numpy.int8,
228-
'int16': numpy.int16,
229-
'int32': numpy.int32,
230-
'int64': numpy.int64}
231-
232-
"""
233-
res = self._dtypes(kind)
234-
for k, v in res.copy().items():
235-
try:
236-
empty((0,), dtype=v, device=device)
237-
except:
238-
del res[k]
239-
return res
240160

241161
def _dtypes(self, kind):
162+
bool = torch.bool
163+
int8 = torch.int8
164+
int16 = torch.int16
165+
int32 = torch.int32
166+
int64 = torch.int64
167+
uint8 = getattr(torch, "uint8", None)
168+
uint16 = getattr(torch, "uint16", None)
169+
uint32 = getattr(torch, "uint32", None)
170+
uint64 = getattr(torch, "uint64", None)
171+
float32 = torch.float32
172+
float64 = torch.float64
173+
complex64 = torch.complex64
174+
complex128 = torch.complex128
175+
242176
if kind is None:
243177
return {
244178
"bool": bool,
@@ -314,6 +248,72 @@ def _dtypes(self, kind):
314248
return res
315249
raise ValueError(f"unsupported kind: {kind!r}")
316250

251+
@cache
252+
def dtypes(self, *, device=None, kind=None):
253+
"""
254+
The array API data types supported by PyTorch.
255+
256+
Note that this function only returns data types that are defined by
257+
the array API.
258+
259+
Parameters
260+
----------
261+
device : str, optional
262+
The device to get the data types for.
263+
kind : str or tuple of str, optional
264+
The kind of data types to return. If ``None``, all data types are
265+
returned. If a string, only data types of that kind are returned.
266+
If a tuple, a dictionary containing the union of the given kinds
267+
is returned. The following kinds are supported:
268+
269+
- ``'bool'``: boolean data types (i.e., ``bool``).
270+
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
271+
``int16``, ``int32``, ``int64``).
272+
- ``'unsigned integer'``: unsigned integer data types (i.e.,
273+
``uint8``, ``uint16``, ``uint32``, ``uint64``).
274+
- ``'integral'``: integer data types. Shorthand for ``('signed
275+
integer', 'unsigned integer')``.
276+
- ``'real floating'``: real-valued floating-point data types
277+
(i.e., ``float32``, ``float64``).
278+
- ``'complex floating'``: complex floating-point data types (i.e.,
279+
``complex64``, ``complex128``).
280+
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
281+
'real floating', 'complex floating')``.
282+
283+
Returns
284+
-------
285+
dtypes : dict
286+
A dictionary mapping the names of data types to the corresponding
287+
PyTorch data types.
288+
289+
See Also
290+
--------
291+
__array_namespace_info__.capabilities,
292+
__array_namespace_info__.default_device,
293+
__array_namespace_info__.default_dtypes,
294+
__array_namespace_info__.devices
295+
296+
Examples
297+
--------
298+
>>> info = np.__array_namespace_info__()
299+
>>> info.dtypes(kind='signed integer')
300+
{'int8': numpy.int8,
301+
'int16': numpy.int16,
302+
'int32': numpy.int32,
303+
'int64': numpy.int64}
304+
305+
"""
306+
res = self._dtypes(kind)
307+
for k, v in res.copy().items():
308+
if v is None:
309+
del res[k]
310+
continue
311+
try:
312+
torch.empty((0,), dtype=v, device=device)
313+
except:
314+
del res[k]
315+
return res
316+
317317
@cache
318318
def devices(self):
319319
"""
@@ -343,7 +343,7 @@ def devices(self):
343343
# message of torch.device to get the list of all possible types of
344344
# device:
345345
try:
346-
device('notadevice')
346+
torch.device('notadevice')
347347
except RuntimeError as e:
348348
# The error message is something like:
349349
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
@@ -358,7 +358,7 @@ def devices(self):
358358
i = 0
359359
while True:
360360
try:
361-
a = empty((0,), device=device(device_name, index=i))
361+
a = torch.empty((0,), device=torch.device(device_name, index=i))
362362
if a.device in devices:
363363
break
364364
devices.append(a.device)

0 commit comments

Comments
 (0)