diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml new file mode 100644 index 00000000..59b70930 --- /dev/null +++ b/.github/workflows/array-api-tests-jax.yml @@ -0,0 +1,13 @@ +name: Array API Tests (JAX) + +on: [push, pull_request] + +jobs: + array-api-tests-jax: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: jax + # See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes + extra-env-vars: | + JAX_ENABLE_X64=1 + ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6e709438..a17514fd 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -33,7 +33,7 @@ on: description: "Multiline string of environment variables to set for the test run." env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline" + PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} -k top_k --hypothesis-disable-deadline" jobs: tests: @@ -50,9 +50,10 @@ jobs: - name: Checkout array-api-tests uses: actions/checkout@v4 with: - repository: data-apis/array-api-tests + repository: JuliaPoo/array-api-tests submodules: 'true' path: array-api-tests + ref: ci-wip-topk-tests - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -77,6 +78,7 @@ jobs: # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak + ARRAY_API_TESTS_VERSION: draft run: | export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat" cd ${GITHUB_WORKSPACE}/array-api-tests diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d2aac8b2..9a4b897d 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -150,6 +150,28 @@ def asarray( return da.asarray(obj, dtype=dtype, **kwargs) + +def top_k( + x: Array, + k: int, + /, + axis: Optional[int] = None, + *, + largest: bool = True, +) -> tuple[Array, Array]: + + if not largest: + k = -k + + # For now, perform the computation twice, + # since an equivalent to numpy's `take_along_axis` + # does not exist. + # See https://github.com/dask/dask/issues/3663. + args = da.argtopk(x, k, axis=axis).compute() + vals = da.topk(x, k, axis=axis).compute() + return vals, args + + from dask.array import ( # Element wise aliases arccos as acos, @@ -178,6 +200,7 @@ def asarray( 'bitwise_right_shift', 'concat', 'pow', 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] + 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type', + 'top_k'] _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py new file mode 100644 index 00000000..4282af15 --- /dev/null +++ b/array_api_compat/jax/__init__.py @@ -0,0 +1,85 @@ +from jax.numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, + # functions + zeros, + all, + any, + isnan, + isfinite, + reshape +) +from jax.numpy import ( + asarray, + s_, + int_, + argpartition, + take_along_axis +) + + +def top_k( + x, + k, + /, + axis=None, + *, + largest=True, +): + # The largest keyword can't be implemented with `jax.lax.top_k` + # efficiently so am using `jax.numpy` for now + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = asarray(x) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (s_[:],) * positive_axis + if largest: + indices_array = argpartition(arr, -k, axis=axis) + slice = slice_start + (s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = argpartition(arr, k-1, axis=axis) + slice = slice_start + (s_[:k],) + topk_indices = indices_array[slice] + + topk_indices = topk_indices.astype(int_) + topk_values = take_along_axis(arr, topk_indices, axis=axis) + return (topk_values, topk_indices) + + +__all__ = ['top_k', 'e', 'inf', 'nan', 'pi', 'newaxis', 'bool', + 'float32', 'float64', 'int8', 'int16', 'int32', + 'int64', 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type', 'zeros', 'all', 'isnan', + 'isfinite', 'reshape', 'any'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 70378716..ae28dac8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,35 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) + +def top_k(a, k, /, axis=-1, *, largest=True): + if k <= 0: + raise ValueError(f'k(={k}) provided must be positive.') + + positive_axis: int + _arr = np.asanyarray(a) + if axis is None: + arr = _arr.ravel() + positive_axis = 0 + else: + arr = _arr + positive_axis = axis if axis > 0 else axis % arr.ndim + + slice_start = (np.s_[:],) * positive_axis + if largest: + indices_array = np.argpartition(arr, -k, axis=axis) + slice = slice_start + (np.s_[-k:],) + topk_indices = indices_array[slice] + else: + indices_array = np.argpartition(arr, k-1, axis=axis) + slice = slice_start + (np.s_[:k],) + topk_indices = indices_array[slice] + + topk_values = np.take_along_axis(arr, topk_indices, axis=axis) + + return (topk_values, topk_indices) + + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -126,6 +155,6 @@ def asarray( __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + 'bitwise_right_shift', 'concat', 'pow', 'top_k'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index fb53e0ee..603dc15e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) +top_k = torch.topk + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', @@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'take', 'top_k'] _all_ignore = ['torch', 'get_xp'] diff --git a/jax-skips.txt b/jax-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/jax-xfails.txt b/jax-xfails.txt new file mode 100644 index 00000000..e69de29b