Skip to content

Commit 0706387

Browse files
authored
Merge pull request #25 from asmeurer/torch-fixes
Various fixes for the torch wrapper
2 parents e2203df + f85f427 commit 0706387

11 files changed

+399
-33
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: Array API Tests (NumPy 1.21)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-numpy:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: numpy
10+
package-version: '== 1.21.*'
11+
xfails-file-extra: '-1-21'
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
name: Array API Tests (NumPy)
1+
name: Array API Tests (NumPy Latest)
22

33
on: [push, pull_request]
44

55
jobs:
6-
array-api-tests-numpy:
6+
array-api-tests-numpy-1-21:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
99
package-name: numpy

.github/workflows/array-api-tests-torch.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Array API Tests (PyTorch)
1+
name: Array API Tests (PyTorch Latest)
22

33
on: [push, pull_request]
44

.github/workflows/array-api-tests.yml

+18-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,21 @@ on:
66
package-name:
77
required: true
88
type: string
9+
package-version:
10+
required: false
11+
type: string
12+
default: '>= 0'
913
pytest-extra-args:
1014
required: false
1115
type: string
16+
# This is not how I would prefer to implement this but it's the only way
17+
# that seems possible with GitHub Actions' limited expressions syntax
18+
xfails-file-extra:
19+
required: false
20+
type: string
21+
skips-file-extra:
22+
required: false
23+
type: string
1224

1325

1426
env:
@@ -41,11 +53,15 @@ jobs:
4153
with:
4254
python-version: ${{ matrix.python-version }}
4355
- name: Install dependencies
56+
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
57+
# to put this in the numpy 1.21 config file.
58+
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
4459
run: |
4560
python -m pip install --upgrade pip
46-
python -m pip install ${{ inputs.package-name }}
61+
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}'
4762
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
4863
- name: Run the array API testsuite (${{ inputs.package-name }})
64+
if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
4965
env:
5066
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }}
5167
# This enables the NEP 50 type promotion behavior (without it a lot of
@@ -54,4 +70,4 @@ jobs:
5470
run: |
5571
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
5672
cd ${GITHUB_WORKSPACE}/array-api-tests
57-
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-skips.txt array_api_tests/
73+
pytest array_api_tests/ --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.xfails-file-extra }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}${{ inputs.skips-file-extra}}-skips.txt ${PYTEST_ARGS}

CHANGELOG.md

+21
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
# 1.1.1 (2023-03-08)
2+
3+
## Minor Changes
4+
5+
- The minimum supported NumPy version is now 1.21. Fixed a few issues with
6+
NumPy 1.21 (with `unique_*` and `asarray`), although there are also a few
7+
known issues with this version (see the README).
8+
9+
- Add `api_version` to `get_namespace()`.
10+
11+
- `get_namespace()` now works correctly with `torch` tensors.
12+
13+
- `get_namespace()` now works correctly with `numpy.array_api` arrays.
14+
15+
- `get_namespace()` now raises `TypeError` instead of `ValueError`.
16+
17+
- Fix the `torch.std` wrapper.
18+
19+
- Add `torch` wrappers for `ones`, `empty`, and `zeros` so that `shape` can be
20+
passed as a keyword argument.
21+
122
# 1.1 (2023-02-24)
223

324
## Major Changes

README.md

+20
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,29 @@ specification:
141141
50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and
142142
https://github.com/numpy/numpy/issues/22341)
143143

144+
- `asarray()` does not support `copy=False`.
145+
144146
- Functions which are not wrapped may not have the same type annotations
145147
as the spec.
146148

147149
- Functions which are not wrapped may not use positional-only arguments.
148150

151+
The minimum supported NumPy version is 1.21. However, this older version of
152+
NumPy has a few issues:
153+
154+
- `unique_*` will not compare nans as unequal.
155+
- `finfo()` has no `smallest_normal`.
156+
- No `from_dlpack` or `__dlpack__`.
157+
- `argmax()` and `argmin()` do not have `keepdims`.
158+
- `qr()` doesn't support matrix stacks.
159+
- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not
160+
supported even in the latest NumPy).
161+
- Type promotion behavior will be value based for 0-D arrays (and there is no
162+
`NPY_PROMOTION_STATE=weak` to disable this).
163+
164+
If any of these are an issue, it is recommended to bump your minimum NumPy
165+
version.
166+
149167
### PyTorch
150168

151169
- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the
@@ -190,6 +208,8 @@ specification:
190208
- As with NumPy, type annotations and positional-only arguments may not
191209
exactly match the spec for functions that are not wrapped at all.
192210

211+
The minimum supported PyTorch version is 1.13.
212+
193213
## Vendoring
194214

195215
This library supports vendoring as an installation method. To vendor the

array_api_compat/common/_aliases.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from typing import NamedTuple
1313
from types import ModuleType
14+
import inspect
1415

1516
from ._helpers import _check_device, _is_numpy_array, get_namespace
1617

@@ -161,13 +162,23 @@ class UniqueInverseResult(NamedTuple):
161162
inverse_indices: ndarray
162163

163164

165+
def _unique_kwargs(xp):
166+
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
167+
# trying to parse version numbers, just check if equal_nan is in the
168+
# signature.
169+
s = inspect.signature(xp.unique)
170+
if 'equal_nan' in s.parameters:
171+
return {'equal_nan': False}
172+
return {}
173+
164174
def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
175+
kwargs = _unique_kwargs(xp)
165176
values, indices, inverse_indices, counts = xp.unique(
166177
x,
167178
return_counts=True,
168179
return_index=True,
169180
return_inverse=True,
170-
equal_nan=False,
181+
**kwargs,
171182
)
172183
# np.unique() flattens inverse indices, but they need to share x's shape
173184
# See https://github.com/numpy/numpy/issues/20638
@@ -181,24 +192,26 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
181192

182193

183194
def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
195+
kwargs = _unique_kwargs(xp)
184196
res = xp.unique(
185197
x,
186198
return_counts=True,
187199
return_index=False,
188200
return_inverse=False,
189-
equal_nan=False,
201+
**kwargs
190202
)
191203

192204
return UniqueCountsResult(*res)
193205

194206

195207
def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
208+
kwargs = _unique_kwargs(xp)
196209
values, inverse_indices = xp.unique(
197210
x,
198211
return_counts=False,
199212
return_index=False,
200213
return_inverse=True,
201-
equal_nan=False,
214+
**kwargs,
202215
)
203216
# xp.unique() flattens inverse indices, but they need to share x's shape
204217
# See https://github.com/numpy/numpy/issues/20638
@@ -207,12 +220,13 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
207220

208221

209222
def unique_values(x: ndarray, /, xp) -> ndarray:
223+
kwargs = _unique_kwargs(xp)
210224
return xp.unique(
211225
x,
212226
return_counts=False,
213227
return_index=False,
214228
return_inverse=False,
215-
equal_nan=False,
229+
**kwargs,
216230
)
217231

218232
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
@@ -295,8 +309,13 @@ def _asarray(
295309
_check_device(xp, device)
296310
if _is_numpy_array(obj):
297311
import numpy as np
298-
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
299-
COPY_TRUE = (True, np._CopyMode.ALWAYS)
312+
if hasattr(np, '_CopyMode'):
313+
# Not present in older NumPys
314+
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
315+
COPY_TRUE = (True, np._CopyMode.ALWAYS)
316+
else:
317+
COPY_FALSE = (False,)
318+
COPY_TRUE = (True,)
300319
else:
301320
COPY_FALSE = (False,)
302321
COPY_TRUE = (True,)

array_api_compat/common/_helpers.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -49,33 +49,50 @@ def is_array_api_obj(x):
4949
or _is_torch_array(x) \
5050
or hasattr(x, '__array_namespace__')
5151

52-
def get_namespace(*xs, _use_compat=True):
52+
def _check_api_version(api_version):
53+
if api_version is not None and api_version != '2021.12':
54+
raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
55+
56+
def get_namespace(*xs, api_version=None, _use_compat=True):
5357
"""
5458
Get the array API compatible namespace for the arrays `xs`.
5559
5660
`xs` should contain one or more arrays.
61+
62+
Typical usage is
63+
64+
def your_function(x, y):
65+
xp = array_api_compat.get_namespace(x, y)
66+
# Now use xp as the array library namespace
67+
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
68+
69+
api_version should be the newest version of the spec that you need support
70+
for (currently the compat library wrapped APIs only support v2021.12).
5771
"""
5872
namespaces = set()
5973
for x in xs:
6074
if isinstance(x, (tuple, list)):
6175
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
6276
elif hasattr(x, '__array_namespace__'):
63-
namespaces.add(x.__array_namespace__())
77+
namespaces.add(x.__array_namespace__(api_version=api_version))
6478
elif _is_numpy_array(x):
79+
_check_api_version(api_version)
6580
if _use_compat:
6681
from .. import numpy as numpy_namespace
6782
namespaces.add(numpy_namespace)
6883
else:
6984
import numpy as np
7085
namespaces.add(np)
7186
elif _is_cupy_array(x):
87+
_check_api_version(api_version)
7288
if _use_compat:
7389
from .. import cupy as cupy_namespace
7490
namespaces.add(cupy_namespace)
7591
else:
7692
import cupy as cp
7793
namespaces.add(cp)
7894
elif _is_torch_array(x):
95+
_check_api_version(api_version)
7996
if _use_compat:
8097
from .. import torch as torch_namespace
8198
namespaces.add(torch_namespace)
@@ -84,13 +101,13 @@ def get_namespace(*xs, _use_compat=True):
84101
namespaces.add(torch)
85102
else:
86103
# TODO: Support Python scalars?
87-
raise ValueError("The input is not a supported array type")
104+
raise TypeError("The input is not a supported array type")
88105

89106
if not namespaces:
90-
raise ValueError("Unrecognized array input")
107+
raise TypeError("Unrecognized array input")
91108

92109
if len(namespaces) != 1:
93-
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
110+
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
94111

95112
xp, = namespaces
96113

array_api_compat/torch/_aliases.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,10 @@ def std(x: array,
361361
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
362362
# implement it here for now.
363363

364-
# if isinstance(correction, float):
365-
# correction = int(correction)
364+
if isinstance(correction, float):
365+
_correction = int(correction)
366+
if correction != _correction:
367+
raise NotImplementedError("float correction in torch std() is not yet supported")
366368

367369
# https://github.com/pytorch/pytorch/issues/29137
368370
if axis == ():
@@ -372,10 +374,10 @@ def std(x: array,
372374
if axis is None:
373375
# torch doesn't support keepdims with axis=None
374376
# (https://github.com/pytorch/pytorch/issues/71209)
375-
res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs)
377+
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
376378
res = _axis_none_keepdims(res, x.ndim, keepdims)
377379
return res
378-
return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs)
380+
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
379381

380382
def var(x: array,
381383
/,
@@ -519,6 +521,28 @@ def full(shape: Union[int, Tuple[int, ...]],
519521

520522
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
521523

524+
# ones, zeros, and empty do not accept shape as a keyword argument
525+
def ones(shape: Union[int, Tuple[int, ...]],
526+
*,
527+
dtype: Optional[Dtype] = None,
528+
device: Optional[Device] = None,
529+
**kwargs) -> array:
530+
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
531+
532+
def zeros(shape: Union[int, Tuple[int, ...]],
533+
*,
534+
dtype: Optional[Dtype] = None,
535+
device: Optional[Device] = None,
536+
**kwargs) -> array:
537+
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
538+
539+
def empty(shape: Union[int, Tuple[int, ...]],
540+
*,
541+
dtype: Optional[Dtype] = None,
542+
device: Optional[Device] = None,
543+
**kwargs) -> array:
544+
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
545+
522546
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
523547
def expand_dims(x: array, /, *, axis: int = 0) -> array:
524548
return torch.unsqueeze(x, axis)
@@ -585,7 +609,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
585609
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
586610
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
587611
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
588-
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
589-
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
590-
'unique_counts', 'unique_inverse', 'unique_values',
612+
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
613+
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
614+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
591615
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']

0 commit comments

Comments
 (0)