Skip to content

Commit d3c4b3c

Browse files
committed
Add inspection namespace for torch
Some of these things have to be inspected manually, and I'm not completely certain everything here is correct.
1 parent 8e3f0b6 commit d3c4b3c

File tree

2 files changed

+378
-6
lines changed

2 files changed

+378
-6
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
_aliases_clip, unstack as _aliases_unstack,)
99
from .._internal import get_xp
1010

11+
from ._info import __array_namespace_info__
12+
1113
import torch
1214

1315
from typing import TYPE_CHECKING
@@ -724,12 +726,13 @@ def sign(x: array, /) -> array:
724726
return out
725727

726728

727-
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
728-
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
729-
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
730-
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
731-
'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal',
732-
'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort',
729+
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
730+
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
731+
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
732+
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
733+
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
734+
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
735+
'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort',
733736
'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
734737
'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where',
735738
'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros',

array_api_compat/torch/_info.py

Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
"""
2+
Array API Inspection namespace
3+
4+
This is the namespace for inspection functions as defined by the array API
5+
standard. See
6+
https://data-apis.org/array-api/latest/API_specification/inspection.html for
7+
more details.
8+
9+
"""
10+
from torch import (
11+
asarray,
12+
get_default_dtype,
13+
device,
14+
empty,
15+
bool,
16+
int8,
17+
int16,
18+
int32,
19+
int64,
20+
uint8,
21+
uint16,
22+
uint32,
23+
uint64,
24+
float32,
25+
float64,
26+
complex64,
27+
complex128,
28+
)
29+
30+
from functools import cache
31+
32+
class __array_namespace_info__:
33+
"""
34+
Get the array API inspection namespace for PyTorch.
35+
36+
The array API inspection namespace defines the following functions:
37+
38+
- capabilities()
39+
- default_device()
40+
- default_dtypes()
41+
- dtypes()
42+
- devices()
43+
44+
See
45+
https://data-apis.org/array-api/latest/API_specification/inspection.html
46+
for more details.
47+
48+
Returns
49+
-------
50+
info : ModuleType
51+
The array API inspection namespace for PyTorch.
52+
53+
Examples
54+
--------
55+
>>> info = np.__array_namespace_info__()
56+
>>> info.default_dtypes()
57+
{'real floating': numpy.float64,
58+
'complex floating': numpy.complex128,
59+
'integral': numpy.int64,
60+
'indexing': numpy.int64}
61+
62+
"""
63+
64+
__module__ = 'torch'
65+
66+
def capabilities(self):
67+
"""
68+
Return a dictionary of array API library capabilities.
69+
70+
The resulting dictionary has the following keys:
71+
72+
- **"boolean indexing"**: boolean indicating whether an array library
73+
supports boolean indexing. Always ``True`` for PyTorch.
74+
75+
- **"data-dependent shapes"**: boolean indicating whether an array
76+
library supports data-dependent output shapes. Always ``True`` for
77+
PyTorch.
78+
79+
See
80+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
81+
for more details.
82+
83+
See Also
84+
--------
85+
__array_namespace_info__.default_device,
86+
__array_namespace_info__.default_dtypes,
87+
__array_namespace_info__.dtypes,
88+
__array_namespace_info__.devices
89+
90+
Returns
91+
-------
92+
capabilities : dict
93+
A dictionary of array API library capabilities.
94+
95+
Examples
96+
--------
97+
>>> info = np.__array_namespace_info__()
98+
>>> info.capabilities()
99+
{'boolean indexing': True,
100+
'data-dependent shapes': True}
101+
102+
"""
103+
return {
104+
"boolean indexing": True,
105+
"data-dependent shapes": True,
106+
# 'max rank' will be part of the 2024.12 standard
107+
# "max rank": 64,
108+
}
109+
110+
def default_device(self):
111+
"""
112+
The default device used for new PyTorch arrays.
113+
114+
See Also
115+
--------
116+
__array_namespace_info__.capabilities,
117+
__array_namespace_info__.default_dtypes,
118+
__array_namespace_info__.dtypes,
119+
__array_namespace_info__.devices
120+
121+
Returns
122+
-------
123+
device : str
124+
The default device used for new PyTorch arrays.
125+
126+
Examples
127+
--------
128+
>>> info = np.__array_namespace_info__()
129+
>>> info.default_device()
130+
'cpu'
131+
132+
"""
133+
return device("cpu")
134+
135+
def default_dtypes(self, *, device=None):
136+
"""
137+
The default data types used for new PyTorch arrays.
138+
139+
Parameters
140+
----------
141+
device : str, optional
142+
The device to get the default data types for. For PyTorch, only
143+
``'cpu'`` is allowed.
144+
145+
Returns
146+
-------
147+
dtypes : dict
148+
A dictionary describing the default data types used for new PyTorch
149+
arrays.
150+
151+
See Also
152+
--------
153+
__array_namespace_info__.capabilities,
154+
__array_namespace_info__.default_device,
155+
__array_namespace_info__.dtypes,
156+
__array_namespace_info__.devices
157+
158+
Examples
159+
--------
160+
>>> info = np.__array_namespace_info__()
161+
>>> info.default_dtypes()
162+
{'real floating': torch.float32,
163+
'complex floating': torch.complex64,
164+
'integral': torch.int64,
165+
'indexing': torch.int64}
166+
167+
"""
168+
default_floating = get_default_dtype()
169+
default_complex = complex64 if default_floating == float32 else complex128
170+
default_integral = asarray(0, device=device).dtype
171+
return {
172+
"real floating": default_floating,
173+
"complex floating": default_complex,
174+
"integral": default_integral,
175+
"indexing": default_integral,
176+
}
177+
178+
@cache
179+
def dtypes(self, *, device=None, kind=None):
180+
"""
181+
The array API data types supported by PyTorch.
182+
183+
Note that this function only returns data types that are defined by
184+
the array API.
185+
186+
Parameters
187+
----------
188+
device : str, optional
189+
The device to get the data types for.
190+
kind : str or tuple of str, optional
191+
The kind of data types to return. If ``None``, all data types are
192+
returned. If a string, only data types of that kind are returned.
193+
If a tuple, a dictionary containing the union of the given kinds
194+
is returned. The following kinds are supported:
195+
196+
- ``'bool'``: boolean data types (i.e., ``bool``).
197+
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
198+
``int16``, ``int32``, ``int64``).
199+
- ``'unsigned integer'``: unsigned integer data types (i.e.,
200+
``uint8``, ``uint16``, ``uint32``, ``uint64``).
201+
- ``'integral'``: integer data types. Shorthand for ``('signed
202+
integer', 'unsigned integer')``.
203+
- ``'real floating'``: real-valued floating-point data types
204+
(i.e., ``float32``, ``float64``).
205+
- ``'complex floating'``: complex floating-point data types (i.e.,
206+
``complex64``, ``complex128``).
207+
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
208+
'real floating', 'complex floating')``.
209+
210+
Returns
211+
-------
212+
dtypes : dict
213+
A dictionary mapping the names of data types to the corresponding
214+
PyTorch data types.
215+
216+
See Also
217+
--------
218+
__array_namespace_info__.capabilities,
219+
__array_namespace_info__.default_device,
220+
__array_namespace_info__.default_dtypes,
221+
__array_namespace_info__.devices
222+
223+
Examples
224+
--------
225+
>>> info = np.__array_namespace_info__()
226+
>>> info.dtypes(kind='signed integer')
227+
{'int8': numpy.int8,
228+
'int16': numpy.int16,
229+
'int32': numpy.int32,
230+
'int64': numpy.int64}
231+
232+
"""
233+
res = self._dtypes(kind)
234+
for k, v in res.copy().items():
235+
try:
236+
empty((0,), dtype=v, device=device)
237+
except:
238+
del res[k]
239+
return res
240+
241+
def _dtypes(self, kind):
242+
if kind is None:
243+
return {
244+
"bool": bool,
245+
"int8": int8,
246+
"int16": int16,
247+
"int32": int32,
248+
"int64": int64,
249+
"uint8": uint8,
250+
"uint16": uint16,
251+
"uint32": uint32,
252+
"uint64": uint64,
253+
"float32": float32,
254+
"float64": float64,
255+
"complex64": complex64,
256+
"complex128": complex128,
257+
}
258+
if kind == "bool":
259+
return {"bool": bool}
260+
if kind == "signed integer":
261+
return {
262+
"int8": int8,
263+
"int16": int16,
264+
"int32": int32,
265+
"int64": int64,
266+
}
267+
if kind == "unsigned integer":
268+
return {
269+
"uint8": uint8,
270+
"uint16": uint16,
271+
"uint32": uint32,
272+
"uint64": uint64,
273+
}
274+
if kind == "integral":
275+
return {
276+
"int8": int8,
277+
"int16": int16,
278+
"int32": int32,
279+
"int64": int64,
280+
"uint8": uint8,
281+
"uint16": uint16,
282+
"uint32": uint32,
283+
"uint64": uint64,
284+
}
285+
if kind == "real floating":
286+
return {
287+
"float32": float32,
288+
"float64": float64,
289+
}
290+
if kind == "complex floating":
291+
return {
292+
"complex64": complex64,
293+
"complex128": complex128,
294+
}
295+
if kind == "numeric":
296+
return {
297+
"int8": int8,
298+
"int16": int16,
299+
"int32": int32,
300+
"int64": int64,
301+
"uint8": uint8,
302+
"uint16": uint16,
303+
"uint32": uint32,
304+
"uint64": uint64,
305+
"float32": float32,
306+
"float64": float64,
307+
"complex64": complex64,
308+
"complex128": complex128,
309+
}
310+
if isinstance(kind, tuple):
311+
res = {}
312+
for k in kind:
313+
res.update(self.dtypes(kind=k))
314+
return res
315+
raise ValueError(f"unsupported kind: {kind!r}")
316+
317+
@cache
318+
def devices(self):
319+
"""
320+
The devices supported by PyTorch.
321+
322+
Returns
323+
-------
324+
devices : list of str
325+
The devices supported by PyTorch.
326+
327+
See Also
328+
--------
329+
__array_namespace_info__.capabilities,
330+
__array_namespace_info__.default_device,
331+
__array_namespace_info__.default_dtypes,
332+
__array_namespace_info__.dtypes
333+
334+
Examples
335+
--------
336+
>>> info = np.__array_namespace_info__()
337+
>>> info.devices()
338+
[device(type='cpu'), device(type='mps', index=0), device(type='meta')]
339+
340+
"""
341+
# Torch doesn't have a straightforward way to get the list of all
342+
# currently supported devices. To do this, we first parse the error
343+
# message of torch.device to get the list of all possible types of
344+
# device:
345+
try:
346+
device('notadevice')
347+
except RuntimeError as e:
348+
# The error message is something like:
349+
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
350+
devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ')
351+
352+
# Next we need to check for different indices for different devices.
353+
# device(device_name, index=index) doesn't actually check if the
354+
# device name or index is valid. We have to try to create a tensor
355+
# with it (which is why this function is cached).
356+
devices = []
357+
for device_name in devices_names:
358+
i = 0
359+
while True:
360+
try:
361+
a = empty((0,), device=device(device_name, index=i))
362+
if a.device in devices:
363+
break
364+
devices.append(a.device)
365+
except:
366+
break
367+
i += 1
368+
369+
return devices

0 commit comments

Comments
 (0)