Skip to content

Commit 9fe035d

Browse files
committed
Add dependency to array-api-compat
1 parent 4360355 commit 9fe035d

File tree

13 files changed

+264
-263
lines changed

13 files changed

+264
-263
lines changed

.github/workflows/test-vendor.yml

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
name: Test vendoring support
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
push:
7+
branches:
8+
- main
9+
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.ref }}
12+
cancel-in-progress: true
13+
14+
env:
15+
# Many color libraries just need this to be set to any value, but at least
16+
# one distinguishes color depth, where "3" -> "256-bit color".
17+
FORCE_COLOR: 3
18+
19+
jobs:
20+
pre-commit-and-lint:
21+
name: Format
22+
runs-on: ubuntu-latest
23+
steps:
24+
- name: Checkout array-api-extra
25+
uses: actions/checkout@v4
26+
with:
27+
path: array-api-extra
28+
29+
- name: Checkout array-api-compat
30+
uses: actions/checkout@v4
31+
with:
32+
repository: data-apis/array-api-compat
33+
path: array-api-compat
34+
35+
- name: Vendor array-api-extra
36+
run: |
37+
cp -a array-api-compat/array_api_compat array-api-extra/vendor_demo/
38+
cp -a array-api-extra/src/array_api_extra array-api-extra/vendor_demo/
39+
40+
- name: Install Python
41+
uses: actions/setup-python@v5
42+
with:
43+
python-version: "3.x"
44+
45+
- name: Install Pixi
46+
uses: prefix-dev/[email protected]
47+
with:
48+
pixi-version: v0.37.0
49+
cache: true
50+
51+
- name: Test package
52+
run: pixi run tests-vendor

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ ENV/
114114
env.bak/
115115
venv.bak/
116116

117-
# Spyder project settings
117+
# IDE project settings
118+
.idea/
118119
.spyderproject
119120
.spyproject
121+
.vscode/
120122

121123
# Rope project settings
122124
.ropeproject

docs/index.md

+38-3
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,46 @@ a specific version, or vendor the library inside your own.
5656
## Vendoring
5757

5858
To vendor the library, clone
59-
[the repository](https://github.com/data-apis/array-api-extra) and copy it into
60-
the appropriate place in your library, like:
59+
[the array-api-extra repository](https://github.com/data-apis/array-api-extra)
60+
and copy it into the appropriate place in your library, like:
6161

6262
```
63-
cp -R array-api-extra/ mylib/vendored/array_api_extra
63+
cp -a array-api-extra/src/array_api_extra mylib/vendored/
64+
```
65+
66+
`array-api-extra` depends on `array-api-compat`. You may either add a dependency
67+
in your own project to `array-api-compat` or vendor it too:
68+
69+
1. Clone
70+
[the array-api-compat repository](https://github.com/data-apis/array-api-compat)
71+
and copy it next to your vendored array-api-extra:
72+
73+
```
74+
cp -a array-api-compat/array_api_compat mylib/vendored/
75+
```
76+
77+
2. Create a new hook file which array-api-extra will use instead of the
78+
top-level `array-api-compat` if present:
79+
80+
```
81+
echo 'from mylib.vendored.array_api_compat import *' > mylib/vendored/_array_api_compat_vendor.py
82+
```
83+
84+
This also allows overriding `array-api-compat` functions if you so wish. E.g.
85+
your `mylib/vendored/_array_api_compat_vendor.py` could look like this:
86+
87+
```python
88+
from mylib.vendored.array_api_compat import *
89+
from mylib.vendored.array_api_compat import array_namespace as _array_namespace_orig
90+
91+
92+
def array_namespace(*xs, **kwargs):
93+
import mylib
94+
95+
if any(isinstance(x, mylib.MyArray) for x in xs):
96+
return mylib
97+
else:
98+
return _array_namespace_orig(*xs, **kwargs)
6499
```
65100

66101
(usage)=

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = []
29+
dependencies = ["array-api-compat"]
3030

3131
[project.optional-dependencies]
3232
tests = [
@@ -96,6 +96,7 @@ numpy = "*"
9696
[tool.pixi.feature.tests.tasks]
9797
tests = { cmd = "pytest" }
9898
tests-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --durations=20" }
99+
tests-vendor = { cmd = "pytest vendor_tests" }
99100

100101
[tool.pixi.feature.docs.dependencies]
101102
sphinx = ">=7.0"

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
array-api-compat

src/array_api_extra/_funcs.py

+48-21
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._lib._typing import Array, ModuleType
88

99
from ._lib import _utils
10+
from ._lib._compat import array_namespace
1011

1112
__all__ = [
1213
"atleast_nd",
@@ -19,7 +20,7 @@
1920
]
2021

2122

22-
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
23+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
2324
"""
2425
Recursively expand the dimension of an array to at least `ndim`.
2526
@@ -28,8 +29,8 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
2829
x : array
2930
ndim : int
3031
The minimum number of dimensions for the result.
31-
xp : array_namespace
32-
The standard-compatible namespace for `x`.
32+
xp : array_namespace, optional
33+
The standard-compatible namespace for `x`. Default: infer
3334
3435
Returns
3536
-------
@@ -53,13 +54,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
5354
True
5455
5556
"""
57+
if xp is None:
58+
xp = array_namespace(x)
59+
5660
if x.ndim < ndim:
5761
x = xp.expand_dims(x, axis=0)
5862
x = atleast_nd(x, ndim=ndim, xp=xp)
5963
return x
6064

6165

62-
def cov(m: Array, /, *, xp: ModuleType) -> Array:
66+
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
6367
"""
6468
Estimate a covariance matrix.
6569
@@ -77,8 +81,8 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
7781
A 1-D or 2-D array containing multiple variables and observations.
7882
Each row of `m` represents a variable, and each column a single
7983
observation of all those variables.
80-
xp : array_namespace
81-
The standard-compatible namespace for `m`.
84+
xp : array_namespace, optional
85+
The standard-compatible namespace for `m`. Default: infer
8286
8387
Returns
8488
-------
@@ -125,6 +129,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
125129
Array(2.14413333, dtype=array_api_strict.float64)
126130
127131
"""
132+
if xp is None:
133+
xp = array_namespace(m)
134+
128135
m = xp.asarray(m, copy=True)
129136
dtype = (
130137
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
@@ -150,7 +157,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
150157
return xp.squeeze(c, axis=axes)
151158

152159

153-
def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
160+
def create_diagonal(
161+
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
162+
) -> Array:
154163
"""
155164
Construct a diagonal array.
156165
@@ -162,8 +171,8 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
162171
Offset from the leading diagonal (default is ``0``).
163172
Use positive ints for diagonals above the leading diagonal,
164173
and negative ints for diagonals below the leading diagonal.
165-
xp : array_namespace
166-
The standard-compatible namespace for `x`.
174+
xp : array_namespace, optional
175+
The standard-compatible namespace for `x`. Default: infer
167176
168177
Returns
169178
-------
@@ -189,6 +198,9 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
189198
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
190199
191200
"""
201+
if xp is None:
202+
xp = array_namespace(x)
203+
192204
if x.ndim != 1:
193205
err_msg = "`x` must be 1-dimensional."
194206
raise ValueError(err_msg)
@@ -200,7 +212,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
200212

201213

202214
def expand_dims(
203-
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
215+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
204216
) -> Array:
205217
"""
206218
Expand the shape of an array.
@@ -220,8 +232,8 @@ def expand_dims(
220232
given by a positive index could also be referred to by a negative index -
221233
that will also result in an error).
222234
Default: ``(0,)``.
223-
xp : array_namespace
224-
The standard-compatible namespace for `a`.
235+
xp : array_namespace, optional
236+
The standard-compatible namespace for `a`. Default: infer
225237
226238
Returns
227239
-------
@@ -265,6 +277,9 @@ def expand_dims(
265277
[2]]], dtype=array_api_strict.int64)
266278
267279
"""
280+
if xp is None:
281+
xp = array_namespace(a)
282+
268283
if not isinstance(axis, tuple):
269284
axis = (axis,)
270285
ndim = a.ndim + len(axis)
@@ -282,7 +297,7 @@ def expand_dims(
282297
return a
283298

284299

285-
def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
300+
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
286301
"""
287302
Kronecker product of two arrays.
288303
@@ -294,8 +309,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
294309
Parameters
295310
----------
296311
a, b : array
297-
xp : array_namespace
298-
The standard-compatible namespace for `a` and `b`.
312+
xp : array_namespace, optional
313+
The standard-compatible namespace for `a` and `b`. Default: infer
299314
300315
Returns
301316
-------
@@ -357,6 +372,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
357372
Array(True, dtype=array_api_strict.bool)
358373
359374
"""
375+
if xp is None:
376+
xp = array_namespace(a, b)
360377

361378
b = xp.asarray(b)
362379
singletons = (1,) * (b.ndim - a.ndim)
@@ -390,7 +407,12 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
390407

391408

392409
def setdiff1d(
393-
x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
410+
x1: Array,
411+
x2: Array,
412+
/,
413+
*,
414+
assume_unique: bool = False,
415+
xp: ModuleType | None = None,
394416
) -> Array:
395417
"""
396418
Find the set difference of two arrays.
@@ -406,8 +428,8 @@ def setdiff1d(
406428
assume_unique : bool
407429
If ``True``, the input arrays are both assumed to be unique, which
408430
can speed up the calculation. Default is ``False``.
409-
xp : array_namespace
410-
The standard-compatible namespace for `x1` and `x2`.
431+
xp : array_namespace, optional
432+
The standard-compatible namespace for `x1` and `x2`. Default: infer
411433
412434
Returns
413435
-------
@@ -427,6 +449,8 @@ def setdiff1d(
427449
Array([1, 2], dtype=array_api_strict.int64)
428450
429451
"""
452+
if xp is None:
453+
xp = array_namespace(x1, x2)
430454

431455
if assume_unique:
432456
x1 = xp.reshape(x1, (-1,))
@@ -436,7 +460,7 @@ def setdiff1d(
436460
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
437461

438462

439-
def sinc(x: Array, /, *, xp: ModuleType) -> Array:
463+
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
440464
r"""
441465
Return the normalized sinc function.
442466
@@ -456,8 +480,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
456480
x : array
457481
Array (possibly multi-dimensional) of values for which to calculate
458482
``sinc(x)``. Must have a real floating point dtype.
459-
xp : array_namespace
460-
The standard-compatible namespace for `x`.
483+
xp : array_namespace, optional
484+
The standard-compatible namespace for `x`. Default: infer
461485
462486
Returns
463487
-------
@@ -511,6 +535,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
511535
-3.89817183e-17], dtype=array_api_strict.float64)
512536
513537
"""
538+
if xp is None:
539+
xp = array_namespace(x)
540+
514541
if not xp.isdtype(x.dtype, "real floating"):
515542
err_msg = "`x` must have a real floating data type."
516543
raise ValueError(err_msg)

0 commit comments

Comments
 (0)