Skip to content

Commit 825b68c

Browse files
committed
New function at()
1 parent dc9fcf0 commit 825b68c

File tree

12 files changed

+7566
-1190
lines changed

12 files changed

+7566
-1190
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
strategy:
4949
fail-fast: false
5050
matrix:
51-
environment: [ci-py310, ci-py313]
51+
environment: [ci-py310, ci-py313, ci-backends]
5252
runs-on: [ubuntu-latest]
5353

5454
steps:

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56+
"jax": ("https://jax.readthedocs.io/en/latest", None),
5657
}
5758

5859
nitpick_ignore = [

pixi.lock

+7,026-1,174
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+40-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.1.1"]
29+
# dependencies = ["array-api-compat>=1.10.0"] # Do not release
3030

3131
[project.optional-dependencies]
3232
tests = [
@@ -63,9 +63,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
6363

6464
[tool.pixi.dependencies]
6565
python = ">=3.10.15,<3.14"
66-
array-api-compat = ">=1.1.1"
66+
# array-api-compat = ">=1.10.0" # Do not release
6767

6868
[tool.pixi.pypi-dependencies]
69+
# Do not release: main at least @ gh#205
70+
array-api-compat = { git = "https://github.com/data-apis/array-api-compat.git" }
6971
array-api-extra = { path = ".", editable = true }
7072

7173
[tool.pixi.feature.lint.dependencies]
@@ -130,6 +132,35 @@ python = "~=3.10.0"
130132
[tool.pixi.feature.py313.dependencies]
131133
python = "~=3.13.0"
132134

135+
# Backends that can run on CPU-only hosts
136+
[tool.pixi.feature.backends.target.linux-64.dependencies]
137+
pytorch = "*"
138+
dask = "*"
139+
sparse = ">=0.15"
140+
jax = "*"
141+
142+
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
143+
pytorch = "*"
144+
dask = "*"
145+
sparse = ">=0.15"
146+
jax = "*"
147+
148+
[tool.pixi.feature.backends.target.win-64.dependencies]
149+
# pytorch = "*" # Package unavailable on Windows
150+
dask = "*"
151+
sparse = ">=0.15"
152+
# jax = "*" # Package unavailable on Windows
153+
154+
# Backends that require a GPU host and a CUDA driver
155+
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
156+
cupy = "*"
157+
158+
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
159+
# cupy = "*" # Package unavailable on macOSX
160+
161+
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
162+
cupy = "*"
163+
133164
[tool.pixi.environments]
134165
default = { solve-group = "default" }
135166
lint = { features = ["lint"], solve-group = "default" }
@@ -138,7 +169,9 @@ docs = { features = ["docs"], solve-group = "default" }
138169
dev = { features = ["lint", "tests", "docs", "dev"], solve-group = "default" }
139170
ci-py310 = ["py310", "tests"]
140171
ci-py313 = ["py313", "tests"]
141-
172+
# CUDA not available on free github actions
173+
ci-backends = ["py310", "tests", "backends"]
174+
tests-backends = ["py310", "tests", "backends", "cuda-backends"]
142175

143176
# pytest
144177

@@ -195,6 +228,8 @@ reportAny = false
195228
reportExplicitAny = false
196229
# data-apis/array-api-strict#6
197230
reportUnknownMemberType = false
231+
# no array-api-compat type stubs
232+
reportUnknownVariableType = false
198233

199234

200235
# Ruff
@@ -236,6 +271,7 @@ ignore = [
236271
"PLR09", # Too many <...>
237272
"PLR2004", # Magic value used in comparison
238273
"ISC001", # Conflicts with formatter
274+
"N801", # Class name should use CapWords convention
239275
"N802", # Function name should be lowercase
240276
"N806", # Variable in function should be lowercase
241277
]
@@ -271,6 +307,7 @@ checks = [
271307
"ES01",
272308
]
273309
exclude = [ # don't report on objects that match any of these regex
310+
'.*test_at.*',
274311
'.*test_funcs.*',
275312
'.*test_utils.*',
276313
'.*test_version.*',

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.4.1.dev0"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

0 commit comments

Comments
 (0)