diff --git a/.github/workflows/test-vendor.yml b/.github/workflows/test-vendor.yml
index 20be3891..6e2662a1 100644
--- a/.github/workflows/test-vendor.yml
+++ b/.github/workflows/test-vendor.yml
@@ -29,7 +29,11 @@ jobs:
- name: Checkout array-api-compat
uses: actions/checkout@v4
with:
- repository: data-apis/array-api-compat
+ # DNM
+ # repository: data-apis/array-api-compat
+ repository: crusaderky/array-api-compat
+ ref: d7ab986843cc9eb20882d7ccbf7248d78fcbd759
+ # /DNM
path: array-api-compat
- name: Vendor array-api-extra into test package
diff --git a/docs/api-reference.md b/docs/api-reference.md
index ffe68f24..b43c960f 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated
+ at
atleast_nd
cov
create_diagonal
diff --git a/pixi.lock b/pixi.lock
index 7161216a..fc1dfcaa 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -9,7 +9,6 @@ environments:
linux-64:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda
- conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda
@@ -49,9 +48,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda
@@ -85,9 +84,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda
- conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda
@@ -103,7 +102,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/libiconv-1.17-hcfcfb64_2.conda
- conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda
- conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
@@ -125,6 +124,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
ci-py313:
channels:
@@ -135,7 +135,6 @@ environments:
linux-64:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda
- conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda
@@ -175,9 +174,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda
@@ -213,9 +212,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda
- conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda
@@ -233,7 +232,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda
- conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
@@ -255,6 +254,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
default:
channels:
@@ -265,7 +265,6 @@ environments:
linux-64:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda
- conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda
@@ -286,9 +285,9 @@ environments:
- conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda
- conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda
@@ -304,16 +303,16 @@ environments:
- conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda
- conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
- conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-h2466b09_0.conda
- conda: https://prefix.dev/conda-forge/win-64/python-3.13.1-h071d269_102_cp313.conda
@@ -324,6 +323,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
dev:
channels:
@@ -335,7 +335,6 @@ environments:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.6-py313h78bf25f_0.conda
- conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda
@@ -459,10 +458,10 @@ environments:
- conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda
- conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py313h80202fe_1.conda
- conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.6-py313h8f79df9_0.conda
- conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda
@@ -581,10 +580,10 @@ environments:
- conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py313hf2da073_1.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.6-py313hfa70ccb_0.conda
- conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda
@@ -631,7 +630,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda
- conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
@@ -703,6 +702,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2
- conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py313h574b89f_1.conda
- conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
docs:
channels:
@@ -714,7 +714,6 @@ environments:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda
- conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py313h46c70d0_2.conda
@@ -782,10 +781,10 @@ environments:
- conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py313h80202fe_1.conda
- conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py313h3579c5c_2.conda
@@ -847,10 +846,10 @@ environments:
- conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2
- conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py313hf2da073_1.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda
- conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py313h5813708_2.conda
@@ -872,7 +871,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
- conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py313hb4c8b1a_1.conda
@@ -914,6 +913,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2
- conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py313h574b89f_1.conda
- conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
lint:
channels:
@@ -924,7 +924,6 @@ environments:
linux-64:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.6-py313h78bf25f_0.conda
- conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda
@@ -993,9 +992,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.6-py313h8f79df9_0.conda
- conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda
@@ -1059,9 +1058,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2
- conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.6-py313hfa70ccb_0.conda
- conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda
@@ -1089,7 +1088,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda
- conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
@@ -1126,6 +1125,7 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda
- conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
tests:
channels:
@@ -1136,7 +1136,6 @@ environments:
linux-64:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
- conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda
- conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda
@@ -1176,9 +1175,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
osx-arm64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda
- conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda
@@ -1214,9 +1213,9 @@ environments:
- conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda
- conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
win-64:
- - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda
- conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda
- conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda
@@ -1234,7 +1233,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda
- conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda
- - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
+ - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
- conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda
- conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda
- conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda
@@ -1256,6 +1255,7 @@ environments:
- conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda
- conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda
+ - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
- pypi: .
packages:
- conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
@@ -1285,27 +1285,27 @@ packages:
depends:
- python >=3.10
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/alabaster?source=hash-mapping
size: 18684
timestamp: 1733750512696
-- conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda
- sha256: 32689f25dd97965043a5ca8a07ae3a9c27278258a16e574b0705bdca7656feff
- md5: f2328337441baa8f669d2a830cfd0097
- depends:
- - python >=3.8
- license: MIT
- license_family: MIT
- purls:
- - pkg:pypi/array-api-compat?source=hash-mapping
- size: 38213
- timestamp: 1730293860305
+- pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759
+ name: array-api-compat
+ version: 1.10.0rc1
+ requires_dist:
+ - cupy ; extra == 'cupy'
+ - dask ; extra == 'dask'
+ - jax ; extra == 'jax'
+ - numpy ; extra == 'numpy'
+ - pytorch ; extra == 'pytorch'
+ - sparse>=0.15.1 ; extra == 'sparse'
+ requires_python: '>=3.9'
- pypi: .
name: array-api-extra
version: 0.3.3.dev0
- sha256: 43892c8bc9d9e1a1a1ff0e3911c62f86de23f9087e71ed66c2c89c4704246ec9
+ sha256: d83424e47948250c57a3a1c1950c7765474ee2a6c76772768e3ecb0979d389c2
requires_dist:
- - array-api-compat>=1.1.1
- furo>=2023.8.17 ; extra == 'docs'
- myst-parser>=0.13 ; extra == 'docs'
- sphinx-autodoc-typehints ; extra == 'docs'
@@ -1648,6 +1648,7 @@ packages:
- python_abi 3.10.* *_cp310
- tomli
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 294010
@@ -1662,6 +1663,7 @@ packages:
- python_abi 3.13.* *_cp313
- tomli
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 371846
@@ -1676,6 +1678,7 @@ packages:
- python_abi 3.10.* *_cp310
- tomli
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 292961
@@ -1690,6 +1693,7 @@ packages:
- python_abi 3.13.* *_cp313
- tomli
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 370606
@@ -1705,6 +1709,7 @@ packages:
- vc >=14.2,<15
- vc14_runtime >=14.29.30139
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 320987
@@ -1720,6 +1725,7 @@ packages:
- vc >=14.2,<15
- vc14_runtime >=14.29.30139
license: Apache-2.0
+ license_family: APACHE
purls:
- pkg:pypi/coverage?source=hash-mapping
size: 396811
@@ -2463,17 +2469,17 @@ packages:
purls: []
size: 850553
timestamp: 1733762057506
-- conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda
- sha256: 3342d6fe787f5830f7e8466d9c65c914bfd8d67220fb5673041b338cbba47afe
- md5: 5b1f36012cc3d09c4eb9f24ad0e2c379
+- conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda
+ sha256: ecfc0182c3b2e63c870581be1fa0e4dbdfec70d2011cb4f5bde416ece26c41df
+ md5: ff00095330e0d35a16bd3bdbd1a2d3e7
depends:
- ucrt >=10.0.20348.0
- vc >=14.2,<15
- vc14_runtime >=14.29.30139
license: Unlicense
purls: []
- size: 892175
- timestamp: 1730208431651
+ size: 891292
+ timestamp: 1733762116902
- conda: https://prefix.dev/conda-forge/linux-64/libstdcxx-14.2.0-hc0a3c3a_1.conda
sha256: 4661af0eb9bdcbb5fb33e5d0023b001ad4be828fccdcc56500059d56f9869462
md5: 234a5554c53625688d51062645337328
@@ -2856,6 +2862,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 7818907
@@ -2875,6 +2882,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 8530595
@@ -2894,6 +2902,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 5890950
@@ -2913,6 +2922,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 6550853
@@ -2932,6 +2942,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 6484276
@@ -2951,6 +2962,7 @@ packages:
constrains:
- numpy-base <0a0
license: BSD-3-Clause
+ license_family: BSD
purls:
- pkg:pypi/numpy?source=hash-mapping
size: 7039077
@@ -3568,6 +3580,7 @@ packages:
- sphinxcontrib-serializinghtml >=1.1.9
- tomli >=2.0
license: BSD-2-Clause
+ license_family: BSD
purls:
- pkg:pypi/sphinx?source=hash-mapping
size: 1387076
diff --git a/pyproject.toml b/pyproject.toml
index 94293f73..1e7b600d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,9 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
-dependencies = ["array-api-compat>=1.1.1"]
+# DNM
+# dependencies = ["array-api-compat>=1.1.1"]
+dependencies = []
[project.optional-dependencies]
tests = [
@@ -63,9 +65,12 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
[tool.pixi.dependencies]
python = ">=3.10.15,<3.14"
-array-api-compat = ">=1.1.1"
+# array-api-compat = ">=1.1.1" # DNM
[tool.pixi.pypi-dependencies]
+# DNM main plus #205, #207, #211
+array-api-compat = { git = "https://github.com/crusaderky/array-api-compat.git", rev = "d7ab986843cc9eb20882d7ccbf7248d78fcbd759" }
+
array-api-extra = { path = ".", editable = true }
[tool.pixi.feature.lint.dependencies]
@@ -190,6 +195,8 @@ reportAny = false
reportExplicitAny = false
# data-apis/array-api-strict#6
reportUnknownMemberType = false
+# no array-api-compat type stubs
+reportUnknownVariableType = false
# Ruff
diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index d1107b1a..bd676fe6 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -1,12 +1,22 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
-from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
+from ._funcs import (
+ at,
+ atleast_nd,
+ cov,
+ create_diagonal,
+ expand_dims,
+ kron,
+ setdiff1d,
+ sinc,
+)
__version__ = "0.3.3.dev0"
# pylint: disable=duplicate-code
__all__ = [
"__version__",
+ "at",
"atleast_nd",
"cov",
"create_diagonal",
diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py
index 3d961e2e..2bc69955 100644
--- a/src/array_api_extra/_funcs.py
+++ b/src/array_api_extra/_funcs.py
@@ -1,15 +1,26 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
+import operator
import typing
import warnings
-if typing.TYPE_CHECKING:
- from ._lib._typing import Array, ModuleType
+# https://github.com/pylint-dev/pylint/issues/10112
+from collections.abc import Callable # pylint: disable=import-error
+from typing import ClassVar, Literal
from ._lib import _utils
-from ._lib._compat import array_namespace
+from ._lib._compat import (
+ array_namespace,
+ is_array_api_obj,
+ is_dask_array,
+ is_writeable_array,
+)
+
+if typing.TYPE_CHECKING:
+ from ._lib._typing import Array, Index, ModuleType, Untyped
__all__ = [
+ "at",
"atleast_nd",
"cov",
"create_diagonal",
@@ -548,3 +559,278 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
)
return xp.sin(y) / y
+
+
+_undef = object()
+
+
+class at: # pylint: disable=invalid-name
+ """
+ Update operations for read-only arrays.
+
+ This implements ``jax.numpy.ndarray.at`` for all backends.
+
+ Parameters
+ ----------
+ x : array
+ Input array.
+ idx : index, optional
+ You may use two alternate syntaxes::
+
+ at(x, idx).set(value) # or get(), add(), etc.
+ at(x)[idx].set(value)
+
+ copy : bool, optional
+ True (default)
+ Ensure that the inputs are not modified.
+ False
+ Ensure that the update operation writes back to the input.
+ Raise ValueError if a copy cannot be avoided.
+ None
+ The array parameter *may* be modified in place if it is possible and
+ beneficial for performance.
+ You should not reuse it after calling this function.
+ xp : array_namespace, optional
+ The standard-compatible namespace for `x`. Default: infer
+
+ **kwargs:
+ If the backend supports an `at` method, any additional keyword
+ arguments are passed to it verbatim; e.g. this allows passing
+ ``indices_are_sorted=True`` to JAX.
+
+ Returns
+ -------
+ Updated input array.
+
+ Examples
+ --------
+ Given either of these equivalent expressions::
+
+ x = at(x)[1].add(2, copy=None)
+ x = at(x, 1).add(2, copy=None)
+
+ If x is a JAX array, they are the same as::
+
+ x = x.at[1].add(2)
+
+ If x is a read-only numpy array, they are the same as::
+
+ x = x.copy()
+ x[1] += 2
+
+ Otherwise, they are the same as::
+
+ x[1] += 2
+
+ Warning
+ -------
+ When you use copy=None, you should always immediately overwrite
+ the parameter array::
+
+ x = at(x, 0).set(2, copy=None)
+
+ The anti-pattern below must be avoided, as it will result in different behaviour
+ on read-only versus writeable arrays::
+
+ x = xp.asarray([0, 0, 0])
+ y = at(x, 0).set(2, copy=None)
+ z = at(x, 1).set(3, copy=None)
+
+ In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
+ when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
+
+ Warning
+ -------
+ The array API standard does not support integer array indices.
+ The behaviour of update methods when the index is an array of integers
+ is undefined; this is particularly true when the index contains multiple
+ occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.
+
+ Note
+ ----
+ `sparse `_ is not supported by update methods yet.
+
+ See Also
+ --------
+ `jax.numpy.ndarray.at `_
+ """
+
+ x: Array
+ idx: Index
+ __slots__: ClassVar[tuple[str, str]] = ("idx", "x")
+
+ def __init__(self, x: Array, idx: Index = _undef, /) -> None:
+ self.x = x
+ self.idx = idx
+
+ def __getitem__(self, idx: Index, /) -> at:
+ """Allow for the alternate syntax ``at(x)[start:stop:step]``,
+ which looks prettier than ``at(x, slice(start, stop, step))``
+ and feels more intuitive coming from the JAX documentation.
+ """
+ if self.idx is not _undef:
+ msg = "Index has already been set"
+ raise ValueError(msg)
+ self.idx = idx
+ return self
+
+ def _common(
+ self,
+ at_op: str,
+ y: Array = _undef,
+ /,
+ copy: bool | None = True,
+ xp: ModuleType | None = None,
+ _is_update: bool = True,
+ **kwargs: Untyped,
+ ) -> tuple[Untyped, None] | tuple[None, Array]:
+ """Perform common prepocessing.
+
+ Returns
+ -------
+ If the operation can be resolved by at[], (return value, None)
+ Otherwise, (None, preprocessed x)
+ """
+ if self.idx is _undef:
+ msg = (
+ "Index has not been set.\n"
+ "Usage: either\n"
+ " at(x, idx).set(value)\n"
+ "or\n"
+ " at(x)[idx].set(value)\n"
+ "(same for all other methods)."
+ )
+ raise TypeError(msg)
+
+ x = self.x
+
+ if copy is None:
+ writeable = is_writeable_array(x)
+ copy = _is_update and not writeable
+ elif copy:
+ writeable = None
+ else:
+ writeable = is_writeable_array(x)
+ if not writeable:
+ msg = "Cannot modify parameter in place"
+ raise ValueError(msg)
+
+ if copy:
+ try:
+ at_ = x.at
+ except AttributeError:
+ # Emulate at[] behaviour for non-JAX arrays
+ # with a copy followed by an update
+ if xp is None:
+ xp = array_namespace(x)
+ # Create writeable copy of read-only numpy array
+ x = xp.asarray(x, copy=True)
+ if writeable is False:
+ # A copy of a read-only numpy array is writeable
+ writeable = None
+ else:
+ # Use JAX's at[] or other library that with the same duck-type API
+ args = (y,) if y is not _undef else ()
+ return getattr(at_[self.idx], at_op)(*args, **kwargs), None
+
+ if _is_update:
+ if writeable is None:
+ writeable = is_writeable_array(x)
+ if not writeable:
+ # sparse crashes here
+ msg = f"Array {x} has no `at` method and is read-only"
+ raise ValueError(msg)
+
+ return None, x
+
+ def get(self, **kwargs: Untyped) -> Untyped:
+ """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
+ that the output is either a copy or a view; it also allows passing
+ keyword arguments to the backend.
+ """
+ if kwargs.get("copy") is False:
+ if is_array_api_obj(self.idx):
+ # Boolean index. Note that the array API spec
+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
+ # does not allow for list, tuple, and tuples of slices plus one or more
+ # one-dimensional array indices, although many backends support them.
+ # So this check will encounter a lot of false negatives in real life,
+ # which can be caught by testing the user code vs. array-api-strict.
+ msg = "get() with an array index always returns a copy"
+ raise ValueError(msg)
+ if is_dask_array(self.x):
+ msg = "get() on Dask arrays always returns a copy"
+ raise ValueError(msg)
+
+ res, x = self._common("get", _is_update=False, **kwargs)
+ if res is not None:
+ return res
+ assert x is not None
+ return x[self.idx]
+
+ def set(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] = y`` and return the update array"""
+ res, x = self._common("set", y, **kwargs)
+ if res is not None:
+ return res
+ assert x is not None
+ x[self.idx] = y
+ return x
+
+ def _iop(
+ self,
+ at_op: Literal[
+ "set", "add", "subtract", "multiply", "divide", "power", "min", "max"
+ ],
+ elwise_op: Callable[[Array, Array], Array],
+ y: Array,
+ /,
+ **kwargs: Untyped,
+ ) -> Array:
+ """x[idx] += y or equivalent in-place operation on a subset of x
+
+ which is the same as saying
+ x[idx] = x[idx] + y
+ Note that this is not the same as
+ operator.iadd(x[idx], y)
+ Consider for example when x is a numpy array and idx is a fancy index, which
+ triggers a deep copy on __getitem__.
+ """
+ res, x = self._common(at_op, y, **kwargs)
+ if res is not None:
+ return res
+ assert x is not None
+ x[self.idx] = elwise_op(x[self.idx], y)
+ return x
+
+ def add(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] += y`` and return the updated array"""
+ return self._iop("add", operator.add, y, **kwargs)
+
+ def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] -= y`` and return the updated array"""
+ return self._iop("subtract", operator.sub, y, **kwargs)
+
+ def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] *= y`` and return the updated array"""
+ return self._iop("multiply", operator.mul, y, **kwargs)
+
+ def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] /= y`` and return the updated array"""
+ return self._iop("divide", operator.truediv, y, **kwargs)
+
+ def power(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] **= y`` and return the updated array"""
+ return self._iop("power", operator.pow, y, **kwargs)
+
+ def min(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
+ xp = array_namespace(self.x)
+ y = xp.asarray(y)
+ return self._iop("min", xp.minimum, y, **kwargs)
+
+ def max(self, y: Array, /, **kwargs: Untyped) -> Array:
+ """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
+ xp = array_namespace(self.x)
+ y = xp.asarray(y)
+ return self._iop("max", xp.maximum, y, **kwargs)
diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py
index 03e47d19..20bbda9d 100644
--- a/src/array_api_extra/_lib/_compat.py
+++ b/src/array_api_extra/_lib/_compat.py
@@ -4,16 +4,25 @@
try:
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
- array_namespace, # pyright: ignore[reportUnknownVariableType]
- device, # pyright: ignore[reportUnknownVariableType]
+ array_namespace,
+ device,
+ is_array_api_obj,
+ is_dask_array,
+ is_writeable_array,
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
- array_namespace, # pyright: ignore[reportUnknownVariableType]
+ array_namespace,
device,
+ is_array_api_obj,
+ is_dask_array,
+ is_writeable_array,
)
-__all__ = [
+__all__ = (
"array_namespace",
"device",
-]
+ "is_array_api_obj",
+ "is_dask_array",
+ "is_writeable_array",
+)
diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi
index 3b4eb436..ec0ece58 100644
--- a/src/array_api_extra/_lib/_compat.pyi
+++ b/src/array_api_extra/_lib/_compat.pyi
@@ -11,3 +11,6 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
+def is_array_api_obj(x: object, /) -> bool: ...
+def is_dask_array(x: object, /) -> bool: ...
+def is_writeable_array(x: object, /) -> bool: ...
diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_typing.py
index f84b1d20..b877f96e 100644
--- a/src/array_api_extra/_lib/_typing.py
+++ b/src/array_api_extra/_lib/_typing.py
@@ -1,15 +1,23 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
import typing
+from collections.abc import Mapping
from types import ModuleType
-from typing import Any
+from typing import Any, Protocol
if typing.TYPE_CHECKING:
from typing_extensions import override
# To be changed to a Protocol later (see data-apis/array-api#589)
- Array = Any # type: ignore[no-any-explicit]
- Device = Any # type: ignore[no-any-explicit]
+ Untyped = Any # type: ignore[no-any-explicit]
+ Array = Untyped
+ Device = Untyped
+ Index = Untyped
+
+ class CanAt(Protocol):
+ @property
+ def at(self) -> Mapping[Index, Untyped]: ...
+
else:
def no_op_decorator(f): # pyright: ignore[reportUnreachable]
@@ -17,6 +25,8 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable]
override = no_op_decorator
+ CanAt = object
+
__all__ = ["ModuleType", "override"]
if typing.TYPE_CHECKING:
- __all__ += ["Array", "Device"]
+ __all__ += ["Array", "Device", "Index", "Untyped"]
diff --git a/tests/test_at.py b/tests/test_at.py
new file mode 100644
index 00000000..98ece9ba
--- /dev/null
+++ b/tests/test_at.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+from contextlib import contextmanager, suppress
+from importlib import import_module
+from typing import TYPE_CHECKING, Final
+
+import numpy as np
+import pytest
+from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
+ array_namespace,
+ is_dask_array,
+ is_pydata_sparse_array,
+ is_writeable_array,
+)
+
+from array_api_extra import at
+
+if TYPE_CHECKING:
+ from array_api_extra._lib._typing import Array, Untyped
+
+all_libraries: Final = (
+ "array_api_strict",
+ "numpy",
+ "numpy_readonly",
+ "cupy",
+ "torch",
+ "dask.array",
+ "sparse",
+ "jax.numpy",
+)
+
+
+@pytest.fixture(params=all_libraries)
+def array(request: pytest.FixtureRequest) -> Array:
+ library = request.param
+ if library == "numpy_readonly":
+ x = np.asarray([10.0, 20.0, 30.0])
+ x.flags.writeable = False
+ else:
+ try:
+ lib = import_module(library)
+ except ImportError:
+ pytest.skip(f"{library} is not installed")
+ x = lib.asarray([10.0, 20.0, 30.0])
+ return x
+
+
+def assert_array_equal(a: Array, b: Array) -> None:
+ xp = array_namespace(a)
+ b = xp.asarray(b)
+ eq = xp.all(a == b)
+ if is_dask_array(a):
+ eq = eq.compute()
+ assert eq
+
+
+@contextmanager
+def assert_copy(array: Array, copy: bool | None) -> Untyped: # type: ignore[no-any-decorated]
+ # dask arrays are writeable, but writing to them will hot-swap the
+ # dask graph inside the collection so that anything that references
+ # the original graph, i.e. the input collection, won't be mutated.
+ if copy is False and not is_writeable_array(array):
+ with pytest.raises((TypeError, ValueError)):
+ yield
+ return
+
+ xp = array_namespace(array)
+ array_orig = xp.asarray(array, copy=True)
+ yield
+
+ expect_copy = not is_writeable_array(array) if copy is None else copy
+ assert_array_equal(xp.all(array == array_orig), expect_copy)
+
+
+@pytest.mark.parametrize("copy", [True, False, None])
+@pytest.mark.parametrize(
+ ("op", "arg", "expect"),
+ [
+ ("set", 40.0, [10.0, 40.0, 40.0]),
+ ("add", 40.0, [10.0, 60.0, 70.0]),
+ ("subtract", 100.0, [10.0, -80.0, -70.0]),
+ ("multiply", 2.0, [10.0, 40.0, 60.0]),
+ ("divide", 2.0, [10.0, 10.0, 15.0]),
+ ("power", 2.0, [10.0, 400.0, 900.0]),
+ ("min", 25.0, [10.0, 20.0, 25.0]),
+ ("max", 25.0, [10.0, 25.0, 30.0]),
+ ],
+)
+def test_update_ops(
+ array: Array, copy: bool | None, op: str, arg: float, expect: list[float]
+):
+ if is_pydata_sparse_array(array):
+ pytest.skip("at() does not support updates on sparse arrays")
+
+ with assert_copy(array, copy):
+ y = getattr(at(array, slice(1, None)), op)(arg, copy=copy)
+ assert isinstance(y, type(array))
+ assert_array_equal(y, expect)
+
+
+@pytest.mark.parametrize("copy", [True, False, None])
+def test_get(array: Array, copy: bool | None):
+ expect_copy = copy
+
+ # dask is mutable, but __getitem__ never returns a view
+ if is_dask_array(array):
+ if copy is False:
+ with pytest.raises(ValueError, match="always returns a copy"):
+ at(array, slice(2)).get(copy=False)
+ return
+ expect_copy = True
+
+ with assert_copy(array, expect_copy):
+ y = at(array, slice(2)).get(copy=copy)
+ assert isinstance(y, type(array))
+ assert_array_equal(y, [10.0, 20.0])
+ # Let assert_copy test that y is a view or copy
+ with suppress(TypeError, ValueError):
+ y[:] = 40
+
+
+def test_get_bool_indices(array: Array):
+ """get() with a boolean array index always returns a copy"""
+ # sparse violates the array API as it doesn't support
+ # a boolean index that is another sparse array.
+ # dask with dask index has NaN size, which complicates testing.
+ if is_pydata_sparse_array(array) or is_dask_array(array):
+ xp = np
+ else:
+ xp = array_namespace(array)
+ idx = xp.asarray([True, False, True])
+
+ with pytest.raises(ValueError, match="copy"):
+ at(array, idx).get(copy=False)
+
+ assert_array_equal(at(array, idx).get(), [10.0, 30.0])
+
+ with assert_copy(array, True):
+ y = at(array, idx).get(copy=True)
+ assert_array_equal(y, [10.0, 30.0])
+ # Let assert_copy test that y is a view or copy
+ with suppress(TypeError, ValueError):
+ y[:] = 40
+
+
+def test_copy_invalid():
+ a = np.asarray([1, 2, 3])
+ with pytest.raises(ValueError, match="copy"):
+ at(a, 0).set(4, copy="invalid")
+
+
+def test_xp():
+ a = np.asarray([1, 2, 3])
+ b = at(a, 0).set(4, xp=np)
+ assert_array_equal(b, [4, 2, 3])
diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py
index 8b00a375..ec6d1c4a 100644
--- a/vendor_tests/test_vendor.py
+++ b/vendor_tests/test_vendor.py
@@ -5,10 +5,21 @@
def test_vendor_compat():
- from ._array_api_compat_vendor import array_namespace
+ from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
+ array_namespace,
+ device,
+ is_array_api_obj,
+ is_dask_array,
+ is_writeable_array,
+ )
x = xp.asarray([1, 2, 3])
assert array_namespace(x) is xp
+ device(x)
+ assert is_array_api_obj(x)
+ assert not is_array_api_obj(123)
+ assert not is_dask_array(x)
+ assert is_writeable_array(x)
def test_vendor_extra():