From 1f34630ad0aa7e5a72174d62b93f86c35195149e Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 15:14:40 +0800 Subject: [PATCH 01/26] WIP: Add `top_k` compatibility This references the PR data-apis/array-api-tests#274. --- .github/workflows/array-api-tests-jax.yml | 10 +++ .github/workflows/array-api-tests.yml | 3 +- array_api_compat/dask/array/_aliases.py | 25 +++++- array_api_compat/jax/__init__.py | 27 ++++++ array_api_compat/numpy/_aliases.py | 103 +++++++++++++++++++++- array_api_compat/torch/_aliases.py | 4 +- 6 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/array-api-tests-jax.yml create mode 100644 array_api_compat/jax/__init__.py diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml new file mode 100644 index 00000000..4e93c7db --- /dev/null +++ b/.github/workflows/array-api-tests-jax.yml @@ -0,0 +1,10 @@ +name: Array API Tests (JAX) + +on: [push, pull_request] + +jobs: + array-api-tests-jax: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: jax + pytest-extra-args: -k top_k diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6e709438..a6361754 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -50,7 +50,8 @@ jobs: - name: Checkout array-api-tests uses: actions/checkout@v4 with: - repository: data-apis/array-api-tests + repository: JuliaPoo/array-api-tests + ref: wip-topk-tests submodules: 'true' path: array-api-tests - name: Set up Python ${{ matrix.python-version }} diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d2aac8b2..9a4b897d 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -150,6 +150,28 @@ def asarray( return da.asarray(obj, dtype=dtype, **kwargs) + +def top_k( + x: Array, + k: int, + /, + axis: Optional[int] = None, + *, + largest: bool = True, +) -> tuple[Array, Array]: + + if not largest: + k = -k + + # For now, perform the computation twice, + # since an equivalent to numpy's `take_along_axis` + # does not exist. + # See https://github.com/dask/dask/issues/3663. + args = da.argtopk(x, k, axis=axis).compute() + vals = da.topk(x, k, axis=axis).compute() + return vals, args + + from dask.array import ( # Element wise aliases arccos as acos, @@ -178,6 +200,7 @@ def asarray( 'bitwise_right_shift', 'concat', 'pow', 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type', + 'top_k'] _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py new file mode 100644 index 00000000..27762936 --- /dev/null +++ b/array_api_compat/jax/__init__.py @@ -0,0 +1,27 @@ +import jax +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Tuple + + from ..common._typing import Array + + +def top_k( + x: Array, + k: int, + /, + axis: Optional[int] = None, + *, + largest: bool = True, +) -> Tuple[Array, Array]: + + # `swapaxes` is used to implement + # the `axis` kwarg + x = jax.numpy.swapaxes(x, axis, -1) + vals, args = jax.lax.top_k(x, k) + vals = jax.numpy.swapaxes(vals, axis, -1) + args = jax.numpy.swapaxes(args, axis, -1) + return vals, args + + +__all__ = ['top_k'] \ No newline at end of file diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 70378716..34720ff0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,107 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) + +def top_k(a, k, /, *, axis=-1, largest=True): + """ + Returns the ``k`` largest/smallest elements and corresponding + indices along the given ``axis``. + + When ``axis`` is None, a flattened array is used. + + If ``largest`` is false, then the ``k`` smallest elements are returned. + + A tuple of ``(values, indices)`` is returned, where ``values`` and + ``indices`` of the largest/smallest elements of each row of the input + array in the given ``axis``. + + Parameters + ---------- + a: array_like + The source array + k: int + The number of largest/smallest elements to return. ``k`` must + be a positive integer and within indexable range specified by + ``axis``. + axis: int, optional + Axis along which to find the largest/smallest elements. + The default is -1 (the last axis). + If None, a flattened array is used. + largest: bool, optional + If True, largest elements are returned. Otherwise the smallest + are returned. + + Returns + ------- + tuple_of_array: tuple + The output tuple of ``(topk_values, topk_indices)``, where + ``topk_values`` are returned elements from the source array + (not necessarily in sorted order), and ``topk_indices`` are + the corresponding indices. + + See Also + -------- + argpartition : Indirect partition. + sort : Full sorting. + + Notes + ----- + The returned indices are not guaranteed to be sorted according to + the values. Furthermore, the returned indices are not guaranteed + to be the earliest/latest occurrence of the element. E.g., + ``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))`` + rather than ``(array([3]), array([0]))`` or + ``(array([3]), array([2]))``. + + Warning: The treatment of ``np.nan`` in the input array is undefined. + + Examples + -------- + >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) + >>> np.top_k(a, 2) + (array([[4, 5], + [4, 5], + [4, 5]]), + array([[3, 4], + [1, 0], + [1, 2]])) + >>> np.top_k(a, 2, axis=0) + (array([[3, 4, 3, 2, 2], + [5, 4, 5, 4, 5]]), + array([[2, 1, 1, 1, 2], + [1, 2, 2, 0, 0]])) + >>> a.flatten() + array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) + >>> np.top_k(a, 2, axis=None) + (array([5, 5]), array([ 5, 12])) + """ + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = np.asanyarray(a) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (np.s_[:],) * positive_axis + if largest: + indices_array = np.argpartition(arr, -k, axis=axis) + slice = slice_start + (np.s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = np.argpartition(arr, k-1, axis=axis) + slice = slice_start + (np.s_[:k],) + topk_indices = indices_array[slice] + + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) + + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -126,6 +227,6 @@ def asarray( __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + 'bitwise_right_shift', 'concat', 'pow', 'top_k'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index fb53e0ee..603dc15e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) +top_k = torch.topk + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', @@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'take', 'top_k'] _all_ignore = ['torch', 'get_xp'] From daaeaa5fc3823ed938cf0b5d23796c993a5aee35 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 15:32:02 +0800 Subject: [PATCH 02/26] idk what im doing --- array_api_compat/jax/__init__.py | 16 +++++----------- array_api_compat/numpy/_aliases.py | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 27762936..010a7cfa 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -1,19 +1,13 @@ import jax -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Tuple - - from ..common._typing import Array - def top_k( - x: Array, - k: int, + x, + k, /, - axis: Optional[int] = None, + axis=None, *, - largest: bool = True, -) -> Tuple[Array, Array]: + largest=True, +): # `swapaxes` is used to implement # the `axis` kwarg diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 34720ff0..985611b4 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -62,7 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) -def top_k(a, k, /, *, axis=-1, largest=True): +def top_k(a, k, /, axis=-1, *, largest=True): """ Returns the ``k`` largest/smallest elements and corresponding indices along the given ``axis``. From cfd838d9fe0f5829e1afa11f87a3d52f303e0850 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 15:36:17 +0800 Subject: [PATCH 03/26] meow --- array_api_compat/jax/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 010a7cfa..94fce703 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -1,4 +1,4 @@ -import jax +from jax import * def top_k( x, @@ -11,10 +11,10 @@ def top_k( # `swapaxes` is used to implement # the `axis` kwarg - x = jax.numpy.swapaxes(x, axis, -1) - vals, args = jax.lax.top_k(x, k) - vals = jax.numpy.swapaxes(vals, axis, -1) - args = jax.numpy.swapaxes(args, axis, -1) + x = numpy.swapaxes(x, axis, -1) + vals, args = lax.top_k(x, k) + vals = numpy.swapaxes(vals, axis, -1) + args = numpy.swapaxes(args, axis, -1) return vals, args From 5c2027d379fbc6fec2e14131322dacd7627f9e02 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 15:42:21 +0800 Subject: [PATCH 04/26] Uwu --- .github/workflows/array-api-tests-jax.yml | 2 +- array_api_compat/jax/__init__.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml index 4e93c7db..9362a9ff 100644 --- a/.github/workflows/array-api-tests-jax.yml +++ b/.github/workflows/array-api-tests-jax.yml @@ -7,4 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: jax - pytest-extra-args: -k top_k + pytest-extra-args: ./array-api-tests/array-api-tests/test_searching_functions.py::test_top_k diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 94fce703..a20ecec9 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -1,5 +1,7 @@ +from jax.numpy import * # quick hack from jax import * + def top_k( x, k, From db94237f70172a12911762af9cd4cc26f64fbb3d Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 15:56:32 +0800 Subject: [PATCH 05/26] why is everything queued --- .github/workflows/array-api-tests-jax.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml index 9362a9ff..4e93c7db 100644 --- a/.github/workflows/array-api-tests-jax.yml +++ b/.github/workflows/array-api-tests-jax.yml @@ -7,4 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: jax - pytest-extra-args: ./array-api-tests/array-api-tests/test_searching_functions.py::test_top_k + pytest-extra-args: -k top_k From 9651d6ddb014d50274e2f4535b22ef1820cc7916 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 16:02:17 +0800 Subject: [PATCH 06/26] please run workflow thx --- .github/workflows/array-api-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index a6361754..1c262223 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -47,6 +47,9 @@ jobs: uses: actions/checkout@v4 with: path: array-api-compat + fetch-depth: 0 + ref: ${{github.event.pull_request.head.ref}} + repository: ${{github.event.pull_request.head.repo.full_name}} - name: Checkout array-api-tests uses: actions/checkout@v4 with: From d5cf944ef7fd9d2e70c9aa53b61a0fe431a0f17f Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 17:25:18 +0800 Subject: [PATCH 07/26] uwu --- .github/workflows/array-api-tests.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 1c262223..f2793ebe 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -47,9 +47,8 @@ jobs: uses: actions/checkout@v4 with: path: array-api-compat - fetch-depth: 0 - ref: ${{github.event.pull_request.head.ref}} - repository: ${{github.event.pull_request.head.repo.full_name}} + repository: JuliaPoo/array-api-compat + ref: topk-compat - name: Checkout array-api-tests uses: actions/checkout@v4 with: From 624075f8cd99ab3a120431ec49bc44cc9c79db68 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 17:30:20 +0800 Subject: [PATCH 08/26] thing --- .github/workflows/array-api-tests.yml | 4 +--- jax-skips.txt | 0 jax-xfails.txt | 0 3 files changed, 1 insertion(+), 3 deletions(-) create mode 100644 jax-skips.txt create mode 100644 jax-xfails.txt diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index f2793ebe..198b8a04 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -47,15 +47,13 @@ jobs: uses: actions/checkout@v4 with: path: array-api-compat - repository: JuliaPoo/array-api-compat - ref: topk-compat - name: Checkout array-api-tests uses: actions/checkout@v4 with: repository: JuliaPoo/array-api-tests - ref: wip-topk-tests submodules: 'true' path: array-api-tests + ref: wip-topk-tests - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/jax-skips.txt b/jax-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/jax-xfails.txt b/jax-xfails.txt new file mode 100644 index 00000000..e69de29b From 6e698d13573936e67fb1b68f91a80a5e896cb30e Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 17:46:17 +0800 Subject: [PATCH 09/26] Point to correct submodule omg --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 198b8a04..ad0b9df5 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -53,7 +53,7 @@ jobs: repository: JuliaPoo/array-api-tests submodules: 'true' path: array-api-tests - ref: wip-topk-tests + ref: ci-wip-topk-tests - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: From 5f5cd6a64ab4f30087a2fb290ef70588c5fc313b Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 18:07:57 +0800 Subject: [PATCH 10/26] smth --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ad0b9df5..06c0bfdc 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -33,7 +33,7 @@ on: description: "Multiline string of environment variables to set for the test run." env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline" + PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} -k top_k --hypothesis-disable-deadline" jobs: tests: From cc590dd5573db66fc0c2607662b0e5da3e78cd98 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 18:11:11 +0800 Subject: [PATCH 11/26] smth --- .github/workflows/array-api-tests-jax.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml index 4e93c7db..c53964e6 100644 --- a/.github/workflows/array-api-tests-jax.yml +++ b/.github/workflows/array-api-tests-jax.yml @@ -7,4 +7,3 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: jax - pytest-extra-args: -k top_k From c2b52a6ad3e64dd71fe8c35b01e126e8c1fae384 Mon Sep 17 00:00:00 2001 From: Jules Date: Wed, 26 Jun 2024 18:19:50 +0800 Subject: [PATCH 12/26] why isnt it using the correct submodule --- .github/workflows/array-api-tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 06c0bfdc..7ac73844 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -40,7 +40,9 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', + #'3.10', '3.11', '3.12' + ] steps: - name: Checkout array-api-compat From 85008c53698c96326efaf773d14a6d1a1c891f86 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 10:48:38 +0800 Subject: [PATCH 13/26] pls work --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 7ac73844..e3b5188e 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -41,7 +41,7 @@ jobs: strategy: matrix: python-version: ['3.9', - #'3.10', '3.11', '3.12' + # '3.10', '3.11', '3.12' ] steps: From ddbda034028cd8c18bca9830218966e19a4e5ac8 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 10:54:16 +0800 Subject: [PATCH 14/26] trigger ci --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index e3b5188e..7ac73844 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -41,7 +41,7 @@ jobs: strategy: matrix: python-version: ['3.9', - # '3.10', '3.11', '3.12' + #'3.10', '3.11', '3.12' ] steps: From e10d187ce01b48ef64a3de3e4ecd7c3f21d5bb13 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 11:02:02 +0800 Subject: [PATCH 15/26] run against draft spec --- .github/workflows/array-api-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 7ac73844..eb972f8c 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -80,6 +80,7 @@ jobs: # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak + ARRAY_API_TESTS_VERSION: draft run: | export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat" cd ${GITHUB_WORKSPACE}/array-api-tests From 3dfa87823d5d6bc520aa91902d036f2337f3dd5d Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 11:06:50 +0800 Subject: [PATCH 16/26] trigger ci --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index eb972f8c..ff3b9920 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -41,7 +41,7 @@ jobs: strategy: matrix: python-version: ['3.9', - #'3.10', '3.11', '3.12' + # '3.10', '3.11', '3.12' ] steps: From f5e1093d4159f9c5eecd554f1a61c435108120c4 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 11:17:39 +0800 Subject: [PATCH 17/26] re-enable tests --- .github/workflows/array-api-tests.yml | 4 +- array_api_compat/jax/__init__.py | 38 ++++++++++++-- array_api_compat/numpy/_aliases.py | 72 --------------------------- 3 files changed, 34 insertions(+), 80 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ff3b9920..a17514fd 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -40,9 +40,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', - # '3.10', '3.11', '3.12' - ] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - name: Checkout array-api-compat diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index a20ecec9..b474d06d 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -13,11 +13,39 @@ def top_k( # `swapaxes` is used to implement # the `axis` kwarg - x = numpy.swapaxes(x, axis, -1) - vals, args = lax.top_k(x, k) - vals = numpy.swapaxes(vals, axis, -1) - args = numpy.swapaxes(args, axis, -1) - return vals, args + # x = numpy.swapaxes(x, axis, -1) + # vals, args = lax.top_k(x, k) + # vals = numpy.swapaxes(vals, axis, -1) + # args = numpy.swapaxes(args, axis, -1) + # return vals, args + + # The largest keyword can't be implemented with `jax.lax.top_k` + # efficiently so am using `jax.numpy` for now + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = asanyarray(a) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (s_[:],) * positive_axis + if largest: + indices_array = argpartition(arr, -k, axis=axis) + slice = slice_start + (s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = argpartition(arr, k-1, axis=axis) + slice = slice_start + (s_[:k],) + topk_indices = indices_array[slice] + + topk_values = take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) __all__ = ['top_k'] \ No newline at end of file diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 985611b4..ae28dac8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -63,78 +63,6 @@ def top_k(a, k, /, axis=-1, *, largest=True): - """ - Returns the ``k`` largest/smallest elements and corresponding - indices along the given ``axis``. - - When ``axis`` is None, a flattened array is used. - - If ``largest`` is false, then the ``k`` smallest elements are returned. - - A tuple of ``(values, indices)`` is returned, where ``values`` and - ``indices`` of the largest/smallest elements of each row of the input - array in the given ``axis``. - - Parameters - ---------- - a: array_like - The source array - k: int - The number of largest/smallest elements to return. ``k`` must - be a positive integer and within indexable range specified by - ``axis``. - axis: int, optional - Axis along which to find the largest/smallest elements. - The default is -1 (the last axis). - If None, a flattened array is used. - largest: bool, optional - If True, largest elements are returned. Otherwise the smallest - are returned. - - Returns - ------- - tuple_of_array: tuple - The output tuple of ``(topk_values, topk_indices)``, where - ``topk_values`` are returned elements from the source array - (not necessarily in sorted order), and ``topk_indices`` are - the corresponding indices. - - See Also - -------- - argpartition : Indirect partition. - sort : Full sorting. - - Notes - ----- - The returned indices are not guaranteed to be sorted according to - the values. Furthermore, the returned indices are not guaranteed - to be the earliest/latest occurrence of the element. E.g., - ``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))`` - rather than ``(array([3]), array([0]))`` or - ``(array([3]), array([2]))``. - - Warning: The treatment of ``np.nan`` in the input array is undefined. - - Examples - -------- - >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]]) - >>> np.top_k(a, 2) - (array([[4, 5], - [4, 5], - [4, 5]]), - array([[3, 4], - [1, 0], - [1, 2]])) - >>> np.top_k(a, 2, axis=0) - (array([[3, 4, 3, 2, 2], - [5, 4, 5, 4, 5]]), - array([[2, 1, 1, 1, 2], - [1, 2, 2, 0, 0]])) - >>> a.flatten() - array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2]) - >>> np.top_k(a, 2, axis=None) - (array([5, 5]), array([ 5, 12])) - """ if k <= 0: raise ValueError(f'k(={k}) provided must be positive.') From 1ab35afc2b9f4785d0598b24fbc6380a05af9809 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 11:28:41 +0800 Subject: [PATCH 18/26] Enable jax 64bit --- .github/workflows/array-api-tests-jax.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml index c53964e6..66672e55 100644 --- a/.github/workflows/array-api-tests-jax.yml +++ b/.github/workflows/array-api-tests-jax.yml @@ -7,3 +7,5 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: jax + extra-env-vars: | + JAX_ENABLE_X64=1 From 00a6a7a19d7ed0be36347d2276164b65fd12ac21 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 12:12:28 +0800 Subject: [PATCH 19/26] fix jax top_k --- array_api_compat/jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index b474d06d..bbc42282 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -25,7 +25,7 @@ def top_k( raise ValueError(f'k(={k}) provided must be positive.') positive_axis: int - _arr = asanyarray(a) + _arr = asarray(a) if axis is None: arr = _arr.ravel() positive_axis = 0 From 966ae70a52687de2b8240c4bbc3cc6f05e495f81 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 12:52:25 +0800 Subject: [PATCH 20/26] fix jax --- array_api_compat/jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index bbc42282..ca9fb43a 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -25,7 +25,7 @@ def top_k( raise ValueError(f'k(={k}) provided must be positive.') positive_axis: int - _arr = asarray(a) + _arr = asarray(x) if axis is None: arr = _arr.ravel() positive_axis = 0 From 2cfbd4f1e5f1205d60c19a3c4175d755a9dbcbab Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 15:57:01 +0800 Subject: [PATCH 21/26] fix jax top_k --- array_api_compat/jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index ca9fb43a..ecfd5d16 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -43,8 +43,8 @@ def top_k( slice = slice_start + (s_[:k],) topk_indices = indices_array[slice] + topk_indices = topk_indices.astype(np.int_) topk_values = take_along_axis(arr, topk_indices, axis=axis) - return (topk_values, topk_indices) From 3820460cad54e6ce7012735f29b241877880c86b Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 16:01:06 +0800 Subject: [PATCH 22/26] fix jax top_k --- array_api_compat/jax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index ecfd5d16..423ec4c8 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -43,7 +43,7 @@ def top_k( slice = slice_start + (s_[:k],) topk_indices = indices_array[slice] - topk_indices = topk_indices.astype(np.int_) + topk_indices = topk_indices.astype(int_) topk_values = take_along_axis(arr, topk_indices, axis=axis) return (topk_values, topk_indices) From db06edea0a20736d01d8019c00e0eba10d1e9500 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 18:23:25 +0800 Subject: [PATCH 23/26] fix jax top_k --- .github/workflows/array-api-tests-jax.yml | 2 ++ array_api_compat/jax/__init__.py | 9 --------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml index 66672e55..59b70930 100644 --- a/.github/workflows/array-api-tests-jax.yml +++ b/.github/workflows/array-api-tests-jax.yml @@ -7,5 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: jax + # See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes extra-env-vars: | JAX_ENABLE_X64=1 + ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64 diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 423ec4c8..2a032442 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -10,15 +10,6 @@ def top_k( *, largest=True, ): - - # `swapaxes` is used to implement - # the `axis` kwarg - # x = numpy.swapaxes(x, axis, -1) - # vals, args = lax.top_k(x, k) - # vals = numpy.swapaxes(vals, axis, -1) - # args = numpy.swapaxes(args, axis, -1) - # return vals, args - # The largest keyword can't be implemented with `jax.lax.top_k` # efficiently so am using `jax.numpy` for now if k <= 0: From 6a5863fbc989d06a88600a7058be38a81e7a0831 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 18:54:46 +0800 Subject: [PATCH 24/26] fix ruff errors --- array_api_compat/jax/__init__.py | 43 +++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 2a032442..d48520f6 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -1,5 +1,42 @@ -from jax.numpy import * # quick hack -from jax import * +from jax.numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, + # functions + zeros, + all, + isnan, + isfinite, + reshape +) +from jax.numpy import ( + asarray, + s_, + int_, + argpartition, + take_along_axis +) def top_k( @@ -39,4 +76,4 @@ def top_k( return (topk_values, topk_indices) -__all__ = ['top_k'] \ No newline at end of file +__all__ = ['top_k'] From 73b7e598d31e052199676ee2935569aa098da199 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 18:57:11 +0800 Subject: [PATCH 25/26] fix ruff errors --- array_api_compat/jax/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index d48520f6..430fd10c 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -76,4 +76,9 @@ def top_k( return (topk_values, topk_indices) -__all__ = ['top_k'] +__all__ = ['top_k', 'e', 'inf', 'nan', 'pi', 'newaxis', 'bool', + 'float32', 'float64', 'int8', 'int16', 'int32', + 'int64', 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type', 'zeros', 'all', 'isnan', + 'isfinite', 'reshape'] From a1977c9a90c37a138ccc5b76be350f29791fb271 Mon Sep 17 00:00:00 2001 From: Jules Date: Thu, 27 Jun 2024 19:01:38 +0800 Subject: [PATCH 26/26] fix jax errors --- array_api_compat/jax/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py index 430fd10c..4282af15 100644 --- a/array_api_compat/jax/__init__.py +++ b/array_api_compat/jax/__init__.py @@ -26,6 +26,7 @@ # functions zeros, all, + any, isnan, isfinite, reshape @@ -81,4 +82,4 @@ def top_k( 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type', 'zeros', 'all', 'isnan', - 'isfinite', 'reshape'] + 'isfinite', 'reshape', 'any']