7
7
more details.
8
8
9
9
"""
10
- from torch import (
11
- asarray ,
12
- get_default_dtype ,
13
- device ,
14
- empty ,
15
- bool ,
16
- int8 ,
17
- int16 ,
18
- int32 ,
19
- int64 ,
20
- uint8 ,
21
- uint16 ,
22
- uint32 ,
23
- uint64 ,
24
- float32 ,
25
- float64 ,
26
- complex64 ,
27
- complex128 ,
28
- )
10
+ import torch
29
11
30
12
from functools import cache
31
13
@@ -130,7 +112,7 @@ def default_device(self):
130
112
'cpu'
131
113
132
114
"""
133
- return device ("cpu" )
115
+ return torch . device ("cpu" )
134
116
135
117
def default_dtypes (self , * , device = None ):
136
118
"""
@@ -165,80 +147,32 @@ def default_dtypes(self, *, device=None):
165
147
'indexing': torch.int64}
166
148
167
149
"""
168
- default_floating = get_default_dtype ()
169
- default_complex = complex64 if default_floating == float32 else complex128
170
- default_integral = asarray (0 , device = device ).dtype
150
+ default_floating = torch . get_default_dtype ()
151
+ default_complex = torch . complex64 if default_floating == torch . float32 else torch . complex128
152
+ default_integral = torch . asarray (0 , device = device ).dtype
171
153
return {
172
154
"real floating" : default_floating ,
173
155
"complex floating" : default_complex ,
174
156
"integral" : default_integral ,
175
157
"indexing" : default_integral ,
176
158
}
177
159
178
- @cache
179
- def dtypes (self , * , device = None , kind = None ):
180
- """
181
- The array API data types supported by PyTorch.
182
-
183
- Note that this function only returns data types that are defined by
184
- the array API.
185
-
186
- Parameters
187
- ----------
188
- device : str, optional
189
- The device to get the data types for.
190
- kind : str or tuple of str, optional
191
- The kind of data types to return. If ``None``, all data types are
192
- returned. If a string, only data types of that kind are returned.
193
- If a tuple, a dictionary containing the union of the given kinds
194
- is returned. The following kinds are supported:
195
-
196
- - ``'bool'``: boolean data types (i.e., ``bool``).
197
- - ``'signed integer'``: signed integer data types (i.e., ``int8``,
198
- ``int16``, ``int32``, ``int64``).
199
- - ``'unsigned integer'``: unsigned integer data types (i.e.,
200
- ``uint8``, ``uint16``, ``uint32``, ``uint64``).
201
- - ``'integral'``: integer data types. Shorthand for ``('signed
202
- integer', 'unsigned integer')``.
203
- - ``'real floating'``: real-valued floating-point data types
204
- (i.e., ``float32``, ``float64``).
205
- - ``'complex floating'``: complex floating-point data types (i.e.,
206
- ``complex64``, ``complex128``).
207
- - ``'numeric'``: numeric data types. Shorthand for ``('integral',
208
- 'real floating', 'complex floating')``.
209
-
210
- Returns
211
- -------
212
- dtypes : dict
213
- A dictionary mapping the names of data types to the corresponding
214
- PyTorch data types.
215
-
216
- See Also
217
- --------
218
- __array_namespace_info__.capabilities,
219
- __array_namespace_info__.default_device,
220
- __array_namespace_info__.default_dtypes,
221
- __array_namespace_info__.devices
222
-
223
- Examples
224
- --------
225
- >>> info = np.__array_namespace_info__()
226
- >>> info.dtypes(kind='signed integer')
227
- {'int8': numpy.int8,
228
- 'int16': numpy.int16,
229
- 'int32': numpy.int32,
230
- 'int64': numpy.int64}
231
-
232
- """
233
- res = self ._dtypes (kind )
234
- for k , v in res .copy ().items ():
235
- try :
236
- empty ((0 ,), dtype = v , device = device )
237
- except :
238
- del res [k ]
239
- return res
240
160
241
161
def _dtypes (self , kind ):
162
+ bool = torch .bool
163
+ int8 = torch .int8
164
+ int16 = torch .int16
165
+ int32 = torch .int32
166
+ int64 = torch .int64
167
+ uint8 = getattr (torch , "uint8" , None )
168
+ uint16 = getattr (torch , "uint16" , None )
169
+ uint32 = getattr (torch , "uint32" , None )
170
+ uint64 = getattr (torch , "uint64" , None )
171
+ float32 = torch .float32
172
+ float64 = torch .float64
173
+ complex64 = torch .complex64
174
+ complex128 = torch .complex128
175
+
242
176
if kind is None :
243
177
return {
244
178
"bool" : bool ,
@@ -314,6 +248,72 @@ def _dtypes(self, kind):
314
248
return res
315
249
raise ValueError (f"unsupported kind: { kind !r} " )
316
250
251
+ @cache
252
+ def dtypes (self , * , device = None , kind = None ):
253
+ """
254
+ The array API data types supported by PyTorch.
255
+
256
+ Note that this function only returns data types that are defined by
257
+ the array API.
258
+
259
+ Parameters
260
+ ----------
261
+ device : str, optional
262
+ The device to get the data types for.
263
+ kind : str or tuple of str, optional
264
+ The kind of data types to return. If ``None``, all data types are
265
+ returned. If a string, only data types of that kind are returned.
266
+ If a tuple, a dictionary containing the union of the given kinds
267
+ is returned. The following kinds are supported:
268
+
269
+ - ``'bool'``: boolean data types (i.e., ``bool``).
270
+ - ``'signed integer'``: signed integer data types (i.e., ``int8``,
271
+ ``int16``, ``int32``, ``int64``).
272
+ - ``'unsigned integer'``: unsigned integer data types (i.e.,
273
+ ``uint8``, ``uint16``, ``uint32``, ``uint64``).
274
+ - ``'integral'``: integer data types. Shorthand for ``('signed
275
+ integer', 'unsigned integer')``.
276
+ - ``'real floating'``: real-valued floating-point data types
277
+ (i.e., ``float32``, ``float64``).
278
+ - ``'complex floating'``: complex floating-point data types (i.e.,
279
+ ``complex64``, ``complex128``).
280
+ - ``'numeric'``: numeric data types. Shorthand for ``('integral',
281
+ 'real floating', 'complex floating')``.
282
+
283
+ Returns
284
+ -------
285
+ dtypes : dict
286
+ A dictionary mapping the names of data types to the corresponding
287
+ PyTorch data types.
288
+
289
+ See Also
290
+ --------
291
+ __array_namespace_info__.capabilities,
292
+ __array_namespace_info__.default_device,
293
+ __array_namespace_info__.default_dtypes,
294
+ __array_namespace_info__.devices
295
+
296
+ Examples
297
+ --------
298
+ >>> info = np.__array_namespace_info__()
299
+ >>> info.dtypes(kind='signed integer')
300
+ {'int8': numpy.int8,
301
+ 'int16': numpy.int16,
302
+ 'int32': numpy.int32,
303
+ 'int64': numpy.int64}
304
+
305
+ """
306
+ res = self ._dtypes (kind )
307
+ for k , v in res .copy ().items ():
308
+ if v is None :
309
+ del res [k ]
310
+ continue
311
+ try :
312
+ torch .empty ((0 ,), dtype = v , device = device )
313
+ except :
314
+ del res [k ]
315
+ return res
316
+
317
317
@cache
318
318
def devices (self ):
319
319
"""
@@ -343,7 +343,7 @@ def devices(self):
343
343
# message of torch.device to get the list of all possible types of
344
344
# device:
345
345
try :
346
- device ('notadevice' )
346
+ torch . device ('notadevice' )
347
347
except RuntimeError as e :
348
348
# The error message is something like:
349
349
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
@@ -358,7 +358,7 @@ def devices(self):
358
358
i = 0
359
359
while True :
360
360
try :
361
- a = empty ((0 ,), device = device (device_name , index = i ))
361
+ a = torch . empty ((0 ,), device = torch . device (device_name , index = i ))
362
362
if a .device in devices :
363
363
break
364
364
devices .append (a .device )
0 commit comments