From 806f037ce983030f053e8f3dcadbce644d425c10 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 9 Dec 2024 13:44:08 +0000 Subject: [PATCH 1/6] DNM depend on array-api-compat@staging --- .github/workflows/test-vendor.yml | 6 ++- pixi.lock | 63 +++++++++++++++---------------- pyproject.toml | 9 ++++- 3 files changed, 42 insertions(+), 36 deletions(-) diff --git a/.github/workflows/test-vendor.yml b/.github/workflows/test-vendor.yml index 20be3891..6e2662a1 100644 --- a/.github/workflows/test-vendor.yml +++ b/.github/workflows/test-vendor.yml @@ -29,7 +29,11 @@ jobs: - name: Checkout array-api-compat uses: actions/checkout@v4 with: - repository: data-apis/array-api-compat + # DNM + # repository: data-apis/array-api-compat + repository: crusaderky/array-api-compat + ref: d7ab986843cc9eb20882d7ccbf7248d78fcbd759 + # /DNM path: array-api-compat - name: Vendor array-api-extra into test package diff --git a/pixi.lock b/pixi.lock index 7161216a..d291f182 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9,7 +9,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -49,9 +48,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -85,9 +84,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -125,6 +124,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . ci-py313: channels: @@ -135,7 +135,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -175,9 +174,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -213,9 +212,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -255,6 +254,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . default: channels: @@ -265,7 +265,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda @@ -286,9 +285,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda @@ -304,9 +303,9 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda @@ -324,6 +323,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . dev: channels: @@ -335,7 +335,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.6-py313h78bf25f_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -459,10 +458,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py313h80202fe_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.6-py313h8f79df9_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -581,10 +580,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py313hf2da073_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.6-py313hfa70ccb_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -703,6 +702,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py313h574b89f_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . docs: channels: @@ -714,7 +714,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py313h46c70d0_2.conda @@ -782,10 +781,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py313h80202fe_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py313h3579c5c_2.conda @@ -847,10 +846,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py313hf2da073_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py313h5813708_2.conda @@ -914,6 +913,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py313h574b89f_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . lint: channels: @@ -924,7 +924,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.6-py313h78bf25f_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -995,7 +994,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.6-py313h8f79df9_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -1061,7 +1059,6 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.6-py313hfa70ccb_0.conda - conda: https://prefix.dev/conda-forge/noarch/basedmypy-2.8.0-pyhd8ed1ab_0.conda @@ -1126,6 +1123,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . tests: channels: @@ -1136,7 +1134,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.8.30-hbcca054_0.conda @@ -1176,9 +1173,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.8.30-hf0a4a13_0.conda @@ -1214,9 +1211,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.8.30-h56e8100_0.conda @@ -1256,6 +1253,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-he29a5d6_23.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hdffcdeb_23.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . packages: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -1289,23 +1287,22 @@ packages: - pkg:pypi/alabaster?source=hash-mapping size: 18684 timestamp: 1733750512696 -- conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.9.1-pyhd8ed1ab_0.conda - sha256: 32689f25dd97965043a5ca8a07ae3a9c27278258a16e574b0705bdca7656feff - md5: f2328337441baa8f669d2a830cfd0097 - depends: - - python >=3.8 - license: MIT - license_family: MIT - purls: - - pkg:pypi/array-api-compat?source=hash-mapping - size: 38213 - timestamp: 1730293860305 +- pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 + name: array-api-compat + version: 1.10.0rc1 + requires_dist: + - cupy ; extra == 'cupy' + - dask ; extra == 'dask' + - jax ; extra == 'jax' + - numpy ; extra == 'numpy' + - pytorch ; extra == 'pytorch' + - sparse>=0.15.1 ; extra == 'sparse' + requires_python: '>=3.9' - pypi: . name: array-api-extra version: 0.3.3.dev0 sha256: 43892c8bc9d9e1a1a1ff0e3911c62f86de23f9087e71ed66c2c89c4704246ec9 requires_dist: - - array-api-compat>=1.1.1 - furo>=2023.8.17 ; extra == 'docs' - myst-parser>=0.13 ; extra == 'docs' - sphinx-autodoc-typehints ; extra == 'docs' diff --git a/pyproject.toml b/pyproject.toml index 94293f73..f1a6f498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,9 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = ["array-api-compat>=1.1.1"] +# DNM +# dependencies = ["array-api-compat>=1.1.1"] +dependencies = [] [project.optional-dependencies] tests = [ @@ -63,9 +65,12 @@ platforms = ["linux-64", "osx-arm64", "win-64"] [tool.pixi.dependencies] python = ">=3.10.15,<3.14" -array-api-compat = ">=1.1.1" +# array-api-compat = ">=1.1.1" # DNM [tool.pixi.pypi-dependencies] +# DNM main plus #205, #207, #211 +array-api-compat = { git = "https://github.com/crusaderky/array-api-compat.git", rev = "d7ab986843cc9eb20882d7ccbf7248d78fcbd759" } + array-api-extra = { path = ".", editable = true } [tool.pixi.feature.lint.dependencies] From aa0d36436f7e9102b384f1b489139b79c8a0f278 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 10 Dec 2024 12:36:18 +0000 Subject: [PATCH 2/6] WIP at() method --- docs/api-reference.md | 1 + pyproject.toml | 3 + src/array_api_extra/__init__.py | 12 +- src/array_api_extra/_funcs.py | 292 ++++++++++++++++++++++++++- src/array_api_extra/_lib/_compat.py | 13 +- src/array_api_extra/_lib/_compat.pyi | 3 + tests/test_at.py | 153 ++++++++++++++ vendor_tests/test_vendor.py | 15 +- 8 files changed, 482 insertions(+), 10 deletions(-) create mode 100644 tests/test_at.py diff --git a/docs/api-reference.md b/docs/api-reference.md index ffe68f24..b43c960f 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -6,6 +6,7 @@ :nosignatures: :toctree: generated + at atleast_nd cov create_diagonal diff --git a/pyproject.toml b/pyproject.toml index f1a6f498..c8096d49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,6 +230,9 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "ISC001", # Conflicts with formatter + # "N802", # Function name should be lowercase + # "N806", # Variable in function should be lowercase + # "PD008", # pandas-use-of-dot-at ] [tool.ruff.lint.per-file-ignores] diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index d1107b1a..bd676fe6 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,12 +1,22 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc +from ._funcs import ( + at, + atleast_nd, + cov, + create_diagonal, + expand_dims, + kron, + setdiff1d, + sinc, +) __version__ = "0.3.3.dev0" # pylint: disable=duplicate-code __all__ = [ "__version__", + "at", "atleast_nd", "cov", "create_diagonal", diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 3d961e2e..599048c3 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,15 +1,21 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 -import typing +import operator import warnings - -if typing.TYPE_CHECKING: - from ._lib._typing import Array, ModuleType +from collections.abc import Callable +from typing import Any from ._lib import _utils -from ._lib._compat import array_namespace +from ._lib._compat import ( + array_namespace, + is_array_api_obj, + is_dask_array, + is_writeable_array, +) +from ._lib._typing import Array, ModuleType __all__ = [ + "at", "atleast_nd", "cov", "create_diagonal", @@ -548,3 +554,279 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), ) return xp.sin(y) / y + + +_undef = object() + + +class at: + """ + Update operations for read-only arrays. + + This implements ``jax.numpy.ndarray.at`` for all backends. + + Parameters + ---------- + x : array + Input array. + idx : index, optional + You may use two alternate syntaxes:: + + at(x, idx).set(value) # or get(), add(), etc. + at(x)[idx].set(value) + + copy : bool, optional + True (default) + Ensure that the inputs are not modified. + False + Ensure that the update operation writes back to the input. + Raise ValueError if a copy cannot be avoided. + None + The array parameter *may* be modified in place if it is possible and + beneficial for performance. + You should not reuse it after calling this function. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer + + **kwargs: + If the backend supports an `at` method, any additional keyword + arguments are passed to it verbatim; e.g. this allows passing + ``indices_are_sorted=True`` to JAX. + + Returns + ------- + Updated input array. + + Examples + -------- + Given either of these equivalent expressions:: + + x = at(x)[1].add(2, copy=None) + x = at(x, 1).add(2, copy=None) + + If x is a JAX array, they are the same as:: + + x = x.at[1].add(2) + + If x is a read-only numpy array, they are the same as:: + + x = x.copy() + x[1] += 2 + + Otherwise, they are the same as:: + + x[1] += 2 + + Warning + ------- + When you use copy=None, you should always immediately overwrite + the parameter array:: + + x = at(x, 0).set(2, copy=None) + + The anti-pattern below must be avoided, as it will result in different behaviour + on read-only versus writeable arrays:: + + x = xp.asarray([0, 0, 0]) + y = at(x, 0).set(2, copy=None) + z = at(x, 1).set(3, copy=None) + + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! + + Warning + ------- + The array API standard does not support integer array indices. + The behaviour of update methods when the index is an array of integers + is undefined; this is particularly true when the index contains multiple + occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``. + + Note + ---- + `sparse `_ is not supported by update methods yet. + + See Also + -------- + `jax.numpy.ndarray.at `_ + """ + + x: Array + idx: Any + __slots__ = ("idx", "x") + + def __init__(self, x: Array, idx: Any = _undef, /): + self.x = x + self.idx = idx + + def __getitem__(self, idx: Any) -> Any: + """Allow for the alternate syntax ``at(x)[start:stop:step]``, + which looks prettier than ``at(x, slice(start, stop, step))`` + and feels more intuitive coming from the JAX documentation. + """ + if self.idx is not _undef: + msg = "Index has already been set" + raise ValueError(msg) + self.idx = idx + return self + + def _common( + self, + at_op: str, + y: Array = _undef, + /, + copy: bool | None = True, + xp: ModuleType | None = None, + _is_update: bool = True, + **kwargs: Any, + ) -> tuple[Any, None] | tuple[None, Array]: + """Perform common prepocessing. + + Returns + ------- + If the operation can be resolved by at[], (return value, None) + Otherwise, (None, preprocessed x) + """ + if self.idx is _undef: + msg = ( + "Index has not been set.\n" + "Usage: either\n" + " at(x, idx).set(value)\n" + "or\n" + " at(x)[idx].set(value)\n" + "(same for all other methods)." + ) + raise TypeError(msg) + + x = self.x + + if copy is True: + writeable = None + elif copy is False: + writeable = is_writeable_array(x) + if not writeable: + msg = "Cannot modify parameter in place" + raise ValueError(msg) + elif copy is None: + writeable = is_writeable_array(x) + copy = _is_update and not writeable + else: + msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] + raise ValueError(msg) + + if copy: + try: + at_ = x.at + except AttributeError: + # Emulate at[] behaviour for non-JAX arrays + # with a copy followed by an update + if xp is None: + xp = array_namespace(x) + # Create writeable copy of read-only numpy array + x = xp.asarray(x, copy=True) + if writeable is False: + # A copy of a read-only numpy array is writeable + writeable = None + else: + # Use JAX's at[] or other library that with the same duck-type API + args = (y,) if y is not _undef else () + return getattr(at_[self.idx], at_op)(*args, **kwargs), None + + if _is_update: + if writeable is None: + writeable = is_writeable_array(x) + if not writeable: + # sparse crashes here + msg = f"Array {x} has no `at` method and is read-only" + raise ValueError(msg) + + return None, x + + def get(self, **kwargs: Any) -> Any: + """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring + that the output is either a copy or a view; it also allows passing + keyword arguments to the backend. + """ + if kwargs.get("copy") is False: + if is_array_api_obj(self.idx): + # Boolean index. Note that the array API spec + # https://data-apis.org/array-api/latest/API_specification/indexing.html + # does not allow for list, tuple, and tuples of slices plus one or more + # one-dimensional array indices, although many backends support them. + # So this check will encounter a lot of false negatives in real life, + # which can be caught by testing the user code vs. array-api-strict. + msg = "get() with an array index always returns a copy" + raise ValueError(msg) + if is_dask_array(self.x): + msg = "get() on Dask arrays always returns a copy" + raise ValueError(msg) + + res, x = self._common("get", _is_update=False, **kwargs) + if res is not None: + return res + assert x is not None + return x[self.idx] + + def set(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = y`` and return the update array""" + res, x = self._common("set", y, **kwargs) + if res is not None: + return res + assert x is not None + x[self.idx] = y + return x + + def _iop( + self, + at_op: str, + elwise_op: Callable[[Array, Array], Array], + y: Array, + /, + **kwargs: Any, + ) -> Array: + """x[idx] += y or equivalent in-place operation on a subset of x + + which is the same as saying + x[idx] = x[idx] + y + Note that this is not the same as + operator.iadd(x[idx], y) + Consider for example when x is a numpy array and idx is a fancy index, which + triggers a deep copy on __getitem__. + """ + res, x = self._common(at_op, y, **kwargs) + if res is not None: + return res + assert x is not None + x[self.idx] = elwise_op(x[self.idx], y) + return x + + def add(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] += y`` and return the updated array""" + return self._iop("add", operator.add, y, **kwargs) + + def subtract(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] -= y`` and return the updated array""" + return self._iop("subtract", operator.sub, y, **kwargs) + + def multiply(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] *= y`` and return the updated array""" + return self._iop("multiply", operator.mul, y, **kwargs) + + def divide(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] /= y`` and return the updated array""" + return self._iop("divide", operator.truediv, y, **kwargs) + + def power(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] **= y`` and return the updated array""" + return self._iop("power", operator.pow, y, **kwargs) + + def min(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("min", xp.minimum, y, **kwargs) + + def max(self, y: Array, /, **kwargs: Any) -> Array: + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" + xp = array_namespace(self.x) + y = xp.asarray(y) + return self._iop("max", xp.maximum, y, **kwargs) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 03e47d19..7189d38e 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -6,14 +6,23 @@ from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] array_namespace, # pyright: ignore[reportUnknownVariableType] device, # pyright: ignore[reportUnknownVariableType] + is_array_api_obj, # pyright: ignore[reportUnknownVariableType] + is_dask_array, # pyright: ignore[reportUnknownVariableType] + is_writeable_array, # pyright: ignore[reportUnknownVariableType] ) except ImportError: from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] array_namespace, # pyright: ignore[reportUnknownVariableType] device, + is_array_api_obj, # pyright: ignore[reportUnknownVariableType] + is_dask_array, # pyright: ignore[reportUnknownVariableType] + is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] ) -__all__ = [ +__all__ = ( "array_namespace", "device", -] + "is_array_api_obj", + "is_dask_array", + "is_writeable_array", +) diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi index 3b4eb436..ec0ece58 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_compat.pyi @@ -11,3 +11,6 @@ def array_namespace( use_compat: bool | None = None, ) -> ArrayModule: ... def device(x: Array, /) -> Device: ... +def is_array_api_obj(x: object, /) -> bool: ... +def is_dask_array(x: object, /) -> bool: ... +def is_writeable_array(x: object, /) -> bool: ... diff --git a/tests/test_at.py b/tests/test_at.py new file mode 100644 index 00000000..d9ce49e6 --- /dev/null +++ b/tests/test_at.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from contextlib import contextmanager, suppress +from importlib import import_module +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from array_api_compat import ( + array_namespace, + is_dask_array, + is_pydata_sparse_array, + is_writeable_array, +) + +from array_api_extra import at + +if TYPE_CHECKING: + from array_api_extra._lib._typing import Array + +all_libraries = ( + "array_api_strict", + "numpy", + "numpy_readonly", + "cupy", + "torch", + "dask.array", + "sparse", + "jax.numpy", +) + + +@pytest.fixture(params=all_libraries) +def array(request): + library = request.param + if library == "numpy_readonly": + x = np.asarray([10.0, 20.0, 30.0]) + x.flags.writeable = False + else: + try: + lib = import_module(library) + except ImportError: + pytest.skip(f"{library} is not installed") + x = lib.asarray([10.0, 20.0, 30.0]) + return x + + +def assert_array_equal(a: Array, b: Array) -> None: + xp = array_namespace(a) + b = xp.asarray(b) + eq = xp.all(a == b) + if is_dask_array(a): + eq = eq.compute() + assert eq + + +@contextmanager +def assert_copy(array, copy: bool | None): + # dask arrays are writeable, but writing to them will hot-swap the + # dask graph inside the collection so that anything that references + # the original graph, i.e. the input collection, won't be mutated. + if copy is False and not is_writeable_array(array): + with pytest.raises((TypeError, ValueError)): + yield + return + + xp = array_namespace(array) + array_orig = xp.asarray(array, copy=True) + yield + + expect_copy = not is_writeable_array(array) if copy is None else copy + assert_array_equal(xp.all(array == array_orig), expect_copy) + + +@pytest.mark.parametrize("copy", [True, False, None]) +@pytest.mark.parametrize( + ("op", "arg", "expect"), + [ + ("set", 40.0, [10.0, 40.0, 40.0]), + ("add", 40.0, [10.0, 60.0, 70.0]), + ("subtract", 100.0, [10.0, -80.0, -70.0]), + ("multiply", 2.0, [10.0, 40.0, 60.0]), + ("divide", 2.0, [10.0, 10.0, 15.0]), + ("power", 2.0, [10.0, 400.0, 900.0]), + ("min", 25.0, [10.0, 20.0, 25.0]), + ("max", 25.0, [10.0, 25.0, 30.0]), + ], +) +def test_update_ops(array, copy, op, arg, expect): + if is_pydata_sparse_array(array): + pytest.skip("at() does not support updates on sparse arrays") + + with assert_copy(array, copy): + y = getattr(at(array, slice(1, None)), op)(arg, copy=copy) + assert isinstance(y, type(array)) + assert_array_equal(y, expect) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_get(array, copy): + expect_copy = copy + + # dask is mutable, but __getitem__ never returns a view + if is_dask_array(array): + if copy is False: + with pytest.raises(ValueError, match="always returns a copy"): + at(array, slice(2)).get(copy=False) + return + expect_copy = True + + with assert_copy(array, expect_copy): + y = at(array, slice(2)).get(copy=copy) + assert isinstance(y, type(array)) + assert_array_equal(y, [10.0, 20.0]) + # Let assert_copy test that y is a view or copy + with suppress(TypeError, ValueError): + y[:] = 40 + + +def test_get_bool_indices(array): + """get() with a boolean array index always returns a copy""" + # sparse violates the array API as it doesn't support + # a boolean index that is another sparse array. + # dask with dask index has NaN size, which complicates testing. + if is_pydata_sparse_array(array) or is_dask_array(array): + xp = np + else: + xp = array_namespace(array) + idx = xp.asarray([True, False, True]) + + with pytest.raises(ValueError, match="copy"): + at(array, idx).get(copy=False) + + assert_array_equal(at(array, idx).get(), [10.0, 30.0]) + + with assert_copy(array, True): + y = at(array, idx).get(copy=True) + assert_array_equal(y, [10.0, 30.0]) + # Let assert_copy test that y is a view or copy + with suppress(TypeError, ValueError): + y[:] = 40 + + +def test_copy_invalid(): + a = np.asarray([1, 2, 3]) + with pytest.raises(ValueError, match="copy"): + at(a, 0).set(4, copy="invalid") + + +def test_xp(): + a = np.asarray([1, 2, 3]) + b = at(a, 0).set(4, xp=np) + assert_array_equal(b, [4, 2, 3]) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 8b00a375..d549d90e 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -5,10 +5,21 @@ def test_vendor_compat(): - from ._array_api_compat_vendor import array_namespace + from ._array_api_compat_vendor import ( # type: ignore[attr-defined] + array_namespace, + device, + is_array_api_obj, + is_dask_array, + is_writeable_array, + ) x = xp.asarray([1, 2, 3]) - assert array_namespace(x) is xp + assert array_namespace(x) is xp # type: ignore[no-untyped-call] + device(x) + assert is_array_api_obj(x) + assert not is_array_api_obj(123) + assert not is_dask_array(x) + assert is_writeable_array(x) def test_vendor_extra(): From de00bdeac14d8e03804eb44aeda8d2a2bb775792 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 14:59:51 +0000 Subject: [PATCH 3/6] appease linter --- pyproject.toml | 5 ++-- src/array_api_extra/_funcs.py | 46 +++++++++++++++-------------- src/array_api_extra/_lib/_compat.py | 18 +++++------ src/array_api_extra/_lib/_typing.py | 4 ++- tests/test_at.py | 16 +++++----- vendor_tests/test_vendor.py | 2 +- 6 files changed, 48 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c8096d49..1e7b600d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,6 +195,8 @@ reportAny = false reportExplicitAny = false # data-apis/array-api-strict#6 reportUnknownMemberType = false +# no array-api-compat type stubs +reportUnknownVariableType = false # Ruff @@ -230,9 +232,6 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "ISC001", # Conflicts with formatter - # "N802", # Function name should be lowercase - # "N806", # Variable in function should be lowercase - # "PD008", # pandas-use-of-dot-at ] [tool.ruff.lint.per-file-ignores] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 599048c3..e0ca4fac 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -2,8 +2,10 @@ import operator import warnings -from collections.abc import Callable -from typing import Any + +# https://github.com/pylint-dev/pylint/issues/10112 +from collections.abc import Callable # pylint: disable=import-error +from typing import ClassVar from ._lib import _utils from ._lib._compat import ( @@ -12,7 +14,7 @@ is_dask_array, is_writeable_array, ) -from ._lib._typing import Array, ModuleType +from ._lib._typing import Array, Index, ModuleType, Untyped __all__ = [ "at", @@ -559,7 +561,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: _undef = object() -class at: +class at: # pylint: disable=invalid-name """ Update operations for read-only arrays. @@ -651,14 +653,14 @@ class at: """ x: Array - idx: Any - __slots__ = ("idx", "x") + idx: Index + __slots__: ClassVar[tuple[str, str]] = ("idx", "x") - def __init__(self, x: Array, idx: Any = _undef, /): + def __init__(self, x: Array, idx: Index = _undef, /): self.x = x self.idx = idx - def __getitem__(self, idx: Any) -> Any: + def __getitem__(self, idx: Index) -> at: """Allow for the alternate syntax ``at(x)[start:stop:step]``, which looks prettier than ``at(x, slice(start, stop, step))`` and feels more intuitive coming from the JAX documentation. @@ -677,8 +679,8 @@ def _common( copy: bool | None = True, xp: ModuleType | None = None, _is_update: bool = True, - **kwargs: Any, - ) -> tuple[Any, None] | tuple[None, Array]: + **kwargs: Untyped, + ) -> tuple[Untyped, None] | tuple[None, Array]: """Perform common prepocessing. Returns @@ -706,11 +708,11 @@ def _common( if not writeable: msg = "Cannot modify parameter in place" raise ValueError(msg) - elif copy is None: + elif copy is None: # type: ignore[redundant-expr] writeable = is_writeable_array(x) copy = _is_update and not writeable else: - msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] + msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable] raise ValueError(msg) if copy: @@ -741,7 +743,7 @@ def _common( return None, x - def get(self, **kwargs: Any) -> Any: + def get(self, **kwargs: Untyped) -> Untyped: """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring that the output is either a copy or a view; it also allows passing keyword arguments to the backend. @@ -766,7 +768,7 @@ def get(self, **kwargs: Any) -> Any: assert x is not None return x[self.idx] - def set(self, y: Array, /, **kwargs: Any) -> Array: + def set(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = y`` and return the update array""" res, x = self._common("set", y, **kwargs) if res is not None: @@ -781,7 +783,7 @@ def _iop( elwise_op: Callable[[Array, Array], Array], y: Array, /, - **kwargs: Any, + **kwargs: Untyped, ) -> Array: """x[idx] += y or equivalent in-place operation on a subset of x @@ -799,33 +801,33 @@ def _iop( x[self.idx] = elwise_op(x[self.idx], y) return x - def add(self, y: Array, /, **kwargs: Any) -> Array: + def add(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] += y`` and return the updated array""" return self._iop("add", operator.add, y, **kwargs) - def subtract(self, y: Array, /, **kwargs: Any) -> Array: + def subtract(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] -= y`` and return the updated array""" return self._iop("subtract", operator.sub, y, **kwargs) - def multiply(self, y: Array, /, **kwargs: Any) -> Array: + def multiply(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] *= y`` and return the updated array""" return self._iop("multiply", operator.mul, y, **kwargs) - def divide(self, y: Array, /, **kwargs: Any) -> Array: + def divide(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] /= y`` and return the updated array""" return self._iop("divide", operator.truediv, y, **kwargs) - def power(self, y: Array, /, **kwargs: Any) -> Array: + def power(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] **= y`` and return the updated array""" return self._iop("power", operator.pow, y, **kwargs) - def min(self, y: Array, /, **kwargs: Any) -> Array: + def min(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" xp = array_namespace(self.x) y = xp.asarray(y) return self._iop("min", xp.minimum, y, **kwargs) - def max(self, y: Array, /, **kwargs: Any) -> Array: + def max(self, y: Array, /, **kwargs: Untyped) -> Array: """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" xp = array_namespace(self.x) y = xp.asarray(y) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 7189d38e..20bbda9d 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -4,19 +4,19 @@ try: from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] - array_namespace, # pyright: ignore[reportUnknownVariableType] - device, # pyright: ignore[reportUnknownVariableType] - is_array_api_obj, # pyright: ignore[reportUnknownVariableType] - is_dask_array, # pyright: ignore[reportUnknownVariableType] - is_writeable_array, # pyright: ignore[reportUnknownVariableType] + array_namespace, + device, + is_array_api_obj, + is_dask_array, + is_writeable_array, ) except ImportError: from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] - array_namespace, # pyright: ignore[reportUnknownVariableType] + array_namespace, device, - is_array_api_obj, # pyright: ignore[reportUnknownVariableType] - is_dask_array, # pyright: ignore[reportUnknownVariableType] - is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] + is_array_api_obj, + is_dask_array, + is_writeable_array, ) __all__ = ( diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_typing.py index f84b1d20..aa99a1a0 100644 --- a/src/array_api_extra/_lib/_typing.py +++ b/src/array_api_extra/_lib/_typing.py @@ -10,6 +10,8 @@ # To be changed to a Protocol later (see data-apis/array-api#589) Array = Any # type: ignore[no-any-explicit] Device = Any # type: ignore[no-any-explicit] + Index = Any # type: ignore[no-any-explicit] + Untyped = Any # type: ignore[no-any-explicit] else: def no_op_decorator(f): # pyright: ignore[reportUnreachable] @@ -19,4 +21,4 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable] __all__ = ["ModuleType", "override"] if typing.TYPE_CHECKING: - __all__ += ["Array", "Device"] + __all__ += ["Array", "Device", "Index", "Untyped"] diff --git a/tests/test_at.py b/tests/test_at.py index d9ce49e6..1c8fa932 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from array_api_compat import ( +from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] array_namespace, is_dask_array, is_pydata_sparse_array, @@ -16,7 +16,7 @@ from array_api_extra import at if TYPE_CHECKING: - from array_api_extra._lib._typing import Array + from array_api_extra._lib._typing import Array, Untyped all_libraries = ( "array_api_strict", @@ -31,7 +31,7 @@ @pytest.fixture(params=all_libraries) -def array(request): +def array(request: pytest.FixtureRequest) -> Array: library = request.param if library == "numpy_readonly": x = np.asarray([10.0, 20.0, 30.0]) @@ -55,7 +55,7 @@ def assert_array_equal(a: Array, b: Array) -> None: @contextmanager -def assert_copy(array, copy: bool | None): +def assert_copy(array: Array, copy: bool | None) -> Untyped: # type: ignore[no-any-decorated] # dask arrays are writeable, but writing to them will hot-swap the # dask graph inside the collection so that anything that references # the original graph, i.e. the input collection, won't be mutated. @@ -86,7 +86,9 @@ def assert_copy(array, copy: bool | None): ("max", 25.0, [10.0, 25.0, 30.0]), ], ) -def test_update_ops(array, copy, op, arg, expect): +def test_update_ops( + array: Array, copy: bool | None, op: str, arg: float, expect: list[float] +): if is_pydata_sparse_array(array): pytest.skip("at() does not support updates on sparse arrays") @@ -97,7 +99,7 @@ def test_update_ops(array, copy, op, arg, expect): @pytest.mark.parametrize("copy", [True, False, None]) -def test_get(array, copy): +def test_get(array: Array, copy: bool | None): expect_copy = copy # dask is mutable, but __getitem__ never returns a view @@ -117,7 +119,7 @@ def test_get(array, copy): y[:] = 40 -def test_get_bool_indices(array): +def test_get_bool_indices(array: Array): """get() with a boolean array index always returns a copy""" # sparse violates the array API as it doesn't support # a boolean index that is another sparse array. diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index d549d90e..ec6d1c4a 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -14,7 +14,7 @@ def test_vendor_compat(): ) x = xp.asarray([1, 2, 3]) - assert array_namespace(x) is xp # type: ignore[no-untyped-call] + assert array_namespace(x) is xp device(x) assert is_array_api_obj(x) assert not is_array_api_obj(123) From bea44274d8fc1a16631a94b13f3c9a4d114f7b2a Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 15:07:31 +0000 Subject: [PATCH 4/6] update lockfile --- pixi.lock | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/pixi.lock b/pixi.lock index d291f182..fc1dfcaa 100644 --- a/pixi.lock +++ b/pixi.lock @@ -102,7 +102,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/libiconv-1.17-hcfcfb64_2.conda - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda @@ -232,7 +232,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda @@ -312,7 +312,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-h2466b09_0.conda - conda: https://prefix.dev/conda-forge/win-64/python-3.13.1-h071d269_102_cp313.conda @@ -630,7 +630,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda @@ -871,7 +871,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py313hb4c8b1a_1.conda @@ -992,6 +992,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda @@ -1057,6 +1058,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.28.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda + - pypi: git+https://github.com/crusaderky/array-api-compat.git@d7ab986843cc9eb20882d7ccbf7248d78fcbd759 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_0.conda @@ -1086,7 +1088,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda @@ -1231,7 +1233,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-25_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libmpdec-4.0.0-h2466b09_0.conda - - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_8.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda @@ -1283,6 +1285,7 @@ packages: depends: - python >=3.10 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/alabaster?source=hash-mapping size: 18684 @@ -1301,7 +1304,7 @@ packages: - pypi: . name: array-api-extra version: 0.3.3.dev0 - sha256: 43892c8bc9d9e1a1a1ff0e3911c62f86de23f9087e71ed66c2c89c4704246ec9 + sha256: d83424e47948250c57a3a1c1950c7765474ee2a6c76772768e3ecb0979d389c2 requires_dist: - furo>=2023.8.17 ; extra == 'docs' - myst-parser>=0.13 ; extra == 'docs' @@ -1645,6 +1648,7 @@ packages: - python_abi 3.10.* *_cp310 - tomli license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 294010 @@ -1659,6 +1663,7 @@ packages: - python_abi 3.13.* *_cp313 - tomli license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 371846 @@ -1673,6 +1678,7 @@ packages: - python_abi 3.10.* *_cp310 - tomli license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 292961 @@ -1687,6 +1693,7 @@ packages: - python_abi 3.13.* *_cp313 - tomli license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 370606 @@ -1702,6 +1709,7 @@ packages: - vc >=14.2,<15 - vc14_runtime >=14.29.30139 license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 320987 @@ -1717,6 +1725,7 @@ packages: - vc >=14.2,<15 - vc14_runtime >=14.29.30139 license: Apache-2.0 + license_family: APACHE purls: - pkg:pypi/coverage?source=hash-mapping size: 396811 @@ -2460,17 +2469,17 @@ packages: purls: [] size: 850553 timestamp: 1733762057506 -- conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.0-h2466b09_1.conda - sha256: 3342d6fe787f5830f7e8466d9c65c914bfd8d67220fb5673041b338cbba47afe - md5: 5b1f36012cc3d09c4eb9f24ad0e2c379 +- conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.47.2-h67fdade_0.conda + sha256: ecfc0182c3b2e63c870581be1fa0e4dbdfec70d2011cb4f5bde416ece26c41df + md5: ff00095330e0d35a16bd3bdbd1a2d3e7 depends: - ucrt >=10.0.20348.0 - vc >=14.2,<15 - vc14_runtime >=14.29.30139 license: Unlicense purls: [] - size: 892175 - timestamp: 1730208431651 + size: 891292 + timestamp: 1733762116902 - conda: https://prefix.dev/conda-forge/linux-64/libstdcxx-14.2.0-hc0a3c3a_1.conda sha256: 4661af0eb9bdcbb5fb33e5d0023b001ad4be828fccdcc56500059d56f9869462 md5: 234a5554c53625688d51062645337328 @@ -2853,6 +2862,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 7818907 @@ -2872,6 +2882,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 8530595 @@ -2891,6 +2902,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 5890950 @@ -2910,6 +2922,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 6550853 @@ -2929,6 +2942,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 6484276 @@ -2948,6 +2962,7 @@ packages: constrains: - numpy-base <0a0 license: BSD-3-Clause + license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping size: 7039077 @@ -3565,6 +3580,7 @@ packages: - sphinxcontrib-serializinghtml >=1.1.9 - tomli >=2.0 license: BSD-2-Clause + license_family: BSD purls: - pkg:pypi/sphinx?source=hash-mapping size: 1387076 From 7ae576681cbb0ea3116f2ae05762758bd97c7400 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 15:11:34 +0000 Subject: [PATCH 5/6] fix import --- src/array_api_extra/_funcs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index e0ca4fac..b2d24257 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,6 +1,7 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 import operator +import typing import warnings # https://github.com/pylint-dev/pylint/issues/10112 @@ -14,7 +15,9 @@ is_dask_array, is_writeable_array, ) -from ._lib._typing import Array, Index, ModuleType, Untyped + +if typing.TYPE_CHECKING: + from ._lib._typing import Array, Index, ModuleType, Untyped __all__ = [ "at", From 55a039a8edb9388020b079b3c05b5918dbe111bd Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 16:36:57 +0000 Subject: [PATCH 6/6] address review --- src/array_api_extra/_funcs.py | 23 +++++++++++------------ src/array_api_extra/_lib/_typing.py | 16 ++++++++++++---- tests/test_at.py | 4 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index b2d24257..2bc69955 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -6,7 +6,7 @@ # https://github.com/pylint-dev/pylint/issues/10112 from collections.abc import Callable # pylint: disable=import-error -from typing import ClassVar +from typing import ClassVar, Literal from ._lib import _utils from ._lib._compat import ( @@ -659,11 +659,11 @@ class at: # pylint: disable=invalid-name idx: Index __slots__: ClassVar[tuple[str, str]] = ("idx", "x") - def __init__(self, x: Array, idx: Index = _undef, /): + def __init__(self, x: Array, idx: Index = _undef, /) -> None: self.x = x self.idx = idx - def __getitem__(self, idx: Index) -> at: + def __getitem__(self, idx: Index, /) -> at: """Allow for the alternate syntax ``at(x)[start:stop:step]``, which looks prettier than ``at(x, slice(start, stop, step))`` and feels more intuitive coming from the JAX documentation. @@ -704,19 +704,16 @@ def _common( x = self.x - if copy is True: + if copy is None: + writeable = is_writeable_array(x) + copy = _is_update and not writeable + elif copy: writeable = None - elif copy is False: + else: writeable = is_writeable_array(x) if not writeable: msg = "Cannot modify parameter in place" raise ValueError(msg) - elif copy is None: # type: ignore[redundant-expr] - writeable = is_writeable_array(x) - copy = _is_update and not writeable - else: - msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable] - raise ValueError(msg) if copy: try: @@ -782,7 +779,9 @@ def set(self, y: Array, /, **kwargs: Untyped) -> Array: def _iop( self, - at_op: str, + at_op: Literal[ + "set", "add", "subtract", "multiply", "divide", "power", "min", "max" + ], elwise_op: Callable[[Array, Array], Array], y: Array, /, diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_typing.py index aa99a1a0..b877f96e 100644 --- a/src/array_api_extra/_lib/_typing.py +++ b/src/array_api_extra/_lib/_typing.py @@ -1,17 +1,23 @@ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 import typing +from collections.abc import Mapping from types import ModuleType -from typing import Any +from typing import Any, Protocol if typing.TYPE_CHECKING: from typing_extensions import override # To be changed to a Protocol later (see data-apis/array-api#589) - Array = Any # type: ignore[no-any-explicit] - Device = Any # type: ignore[no-any-explicit] - Index = Any # type: ignore[no-any-explicit] Untyped = Any # type: ignore[no-any-explicit] + Array = Untyped + Device = Untyped + Index = Untyped + + class CanAt(Protocol): + @property + def at(self) -> Mapping[Index, Untyped]: ... + else: def no_op_decorator(f): # pyright: ignore[reportUnreachable] @@ -19,6 +25,8 @@ def no_op_decorator(f): # pyright: ignore[reportUnreachable] override = no_op_decorator + CanAt = object + __all__ = ["ModuleType", "override"] if typing.TYPE_CHECKING: __all__ += ["Array", "Device", "Index", "Untyped"] diff --git a/tests/test_at.py b/tests/test_at.py index 1c8fa932..98ece9ba 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -2,7 +2,7 @@ from contextlib import contextmanager, suppress from importlib import import_module -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Final import numpy as np import pytest @@ -18,7 +18,7 @@ if TYPE_CHECKING: from array_api_extra._lib._typing import Array, Untyped -all_libraries = ( +all_libraries: Final = ( "array_api_strict", "numpy", "numpy_readonly",