Skip to content

Support the copy keyword in asarray #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e75ba03
Fix some internal documentation
asmeurer Mar 19, 2024
41719af
Factor out the list of wrapped libraries in the tests
asmeurer Mar 19, 2024
f1068a3
Fix typo
asmeurer Mar 19, 2024
73047a9
Add a test for the copy flag in asarray
asmeurer Mar 20, 2024
fa36c20
Properly test copy=None with the dtype argument
asmeurer Mar 20, 2024
558e4ed
Move the asarray numpy implementation to numpy/_aliases
asmeurer Mar 20, 2024
440e1c1
Update xfails
asmeurer Mar 21, 2024
f57ac54
Add a CuPy specific implementation for asarray
asmeurer Mar 21, 2024
166c650
Update test to test that CuPy does not handle copy=False
asmeurer Mar 21, 2024
a194344
Update cupy buffer protocol copy test
asmeurer Mar 21, 2024
a1eea09
Merge branch 'main' into asarray-copy
asmeurer Mar 21, 2024
d84983a
Structure the copy flag in cupy.asarray better
asmeurer Mar 22, 2024
11d27dd
Remove no longer correct note from docstring
asmeurer Mar 22, 2024
f6b5ea2
Add dask.array specific implementation of asarray()
asmeurer Mar 22, 2024
7c0116c
Run the normal tests against different versions of numpy
asmeurer Mar 22, 2024
4d54461
Fix workflow synatx
asmeurer Mar 22, 2024
d7807e1
Install everything in one pip command
asmeurer Mar 22, 2024
354e007
Drop support for Python 3.8
asmeurer Mar 22, 2024
dfac540
Update extras_require in setup.py
asmeurer Mar 22, 2024
c7b5780
Fix ruff errors
asmeurer Mar 22, 2024
3e1f24c
Only run numpy tests for numpy 1.21
asmeurer Mar 22, 2024
cee1696
Don't include "jax.numpy" in the numpy-only tests
asmeurer Mar 25, 2024
c58fbec
Test Python 3.12 on CI
asmeurer Mar 25, 2024
2e5c759
Skip NumPy 1.21 in Python 3.12
asmeurer Mar 25, 2024
04551ed
Fix bash syntax
asmeurer Mar 25, 2024
f7fb29f
Only run numpy specific tests for numpy=dev
asmeurer Mar 25, 2024
689366a
Fix bash syntax
asmeurer Mar 25, 2024
c060cee
Run tests with -v
asmeurer Mar 25, 2024
4112eaf
Disable other libraries too for the numpy-only tests
asmeurer Mar 25, 2024
b171583
Add requirements-dev.txt
asmeurer Mar 25, 2024
8aa76b7
Add setuptools to requirements-dev.txt
asmeurer Mar 25, 2024
7105866
Use pip for the test install
asmeurer Mar 25, 2024
aeb1cc4
Fix workflow syntax
asmeurer Mar 25, 2024
f7e724e
Try again to fix workflow syntax
asmeurer Mar 25, 2024
d7e8532
Keep trying to fix the workflow syntax
asmeurer Mar 25, 2024
fbd6e1b
Try more syntax fixes
asmeurer Mar 25, 2024
6397108
Fix workflow syntax
asmeurer Mar 25, 2024
d0d068d
Merge branch 'main' into asarray-copy
asmeurer Mar 27, 2024
764d5ab
Better job names for docs build and deploy
asmeurer Mar 27, 2024
d5a7cc6
Merge branch 'main' into asarray-copy
asmeurer Mar 27, 2024
2dcd864
Merge branch 'asarray-copy' of github.com:asmeurer/array-api-compat i…
asmeurer Mar 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
- name: Checkout array-api-compat
Expand All @@ -55,16 +55,15 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
# NumPy 1.21 doesn't support Python 3.11. NumPy 2.0 doesn't support
# Python 3.8. There doesn't seem to be a way to put this in the numpy
# 1.21 config file.
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
# NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
# to put this in the numpy 1.21 config file.
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
run: |
python -m pip install --upgrade pip
python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }}
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the array API testsuite (${{ inputs.package-name }})
if: "! ((matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21')) || (matrix.python-version == '3.8' && inputs.package-name == 'numpy' && contains(inputs.xfails-file-extra, 'dev')))"
if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))"
env:
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }}
# This enables the NEP 50 type promotion behavior (without it a lot of
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Docs Build
on: [push, pull_request]

jobs:
build:
docs-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- main

jobs:
deploy:
docs-deploy:
runs-on: ubuntu-latest
environment:
name: docs-deploy
Expand Down
24 changes: 20 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.21', '1.26', 'dev']
exclude:
- python-version: '3.11'
numpy-version: '1.21'
- python-version: '3.12'
numpy-version: '1.21'
fail-fast: true
steps:
- uses: actions/checkout@v4
Expand All @@ -15,11 +21,21 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch dask[array] jax[cpu]
if [ "${{ matrix.numpy-version }}" == "dev" ]; then
PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple'
elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then
PIP_EXTRA='numpy==1.21.*'
else
PIP_EXTRA='numpy==1.26.*'
fi
python -m pip install -r requirements-dev.txt $PIP_EXTRA

- name: Run Tests
run: |
pytest
if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then
PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask")
fi
pytest -v "${PYTEST_EXTRA[@]}"

# Make sure it installs
python setup.py install
python -m pip install .
92 changes: 4 additions & 88 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from typing import Optional, Sequence, Tuple, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
from ._typing import ndarray, Device, Dtype

from typing import NamedTuple
from types import ModuleType
import inspect

from ._helpers import _check_device, is_numpy_array, array_namespace
from ._helpers import _check_device

# These functions are modified from the NumPy versions.

# Creation functions add the device keyword (which does nothing for NumPy)

def arange(
start: Union[int, float],
/,
Expand Down Expand Up @@ -268,90 +268,6 @@ def var(
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
return xp.transpose(x, axes)

# Creation functions add the device keyword (which does nothing for NumPy)

# asarray also adds the copy keyword
def _asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
namespace = None,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().

See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.

"""
if namespace is None:
try:
xp = array_namespace(obj, _use_compat=False)
except ValueError:
# TODO: What about lists of arrays?
raise ValueError("A namespace must be specified for asarray() with non-array input")
elif isinstance(namespace, ModuleType):
xp = namespace
elif namespace == 'numpy':
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")

_check_device(xp, device)
if is_numpy_array(obj):
import numpy as np
if hasattr(np, '_CopyMode'):
# Not present in older NumPys
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
COPY_TRUE = (True, np._CopyMode.ALWAYS)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
if copy in COPY_FALSE and namespace != "dask.array":
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
return obj
elif namespace == "dask.array":
if copy in COPY_TRUE:
if dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
import numpy as np
obj = np.array(obj, dtype=dtype, copy=True)
return xp.array(obj, dtype=dtype)
else:
import dask.array as da
import numpy as np
if not isinstance(obj, da.Array):
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return xp.asarray(obj, dtype=dtype, **kwargs)

# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray,
/,
Expand Down
57 changes: 51 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from functools import partial

import cupy as cp

from ..common import _aliases
from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

bool = cp.bool_

Expand Down Expand Up @@ -62,6 +61,52 @@
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)

_copy_default = object()

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: Optional[bool] = _copy_default,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
# in asarray in numpy/_aliases.py.
if copy is not _copy_default:
# A future version of CuPy will change the meaning of copy=False
# to mean no-copy. We don't know for certain what version it will
# be yet, so to avoid breaking that version, we use a different
# default value for copy so asarray(obj) with no copy kwarg will
# always do the copy-if-needed behavior.

# This will still need to be updated to remove the
# NotImplementedError for copy=False, but at least this won't
# break the default or existing behavior.
if copy is None:
copy = False
elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
kwargs['copy'] = copy

return cp.array(obj, dtype=dtype, **kwargs)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -73,7 +118,7 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
Expand Down
47 changes: 42 additions & 5 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
if TYPE_CHECKING:
from typing import Optional, Union

from ...common._typing import Device, Dtype, Array
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol

import dask.array as da

Expand Down Expand Up @@ -76,10 +76,6 @@ def _dask_arange(
arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
Expand Down Expand Up @@ -113,6 +109,47 @@ def _dask_arange(
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)


# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
Array,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
**kwargs,
) -> Array:
"""
Array API compatibility wrapper for asarray().

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj

return da.asarray(obj, dtype=dtype, **kwargs)

from dask.array import (
# Element wise aliases
arccos as acos,
Expand Down
6 changes: 6 additions & 0 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,10 @@

from ..common._helpers import * # noqa: F403

try:
# Used in asarray(). Not present in older versions.
from numpy import _CopyMode # noqa: F401
except ImportError:
pass

__array_api_version__ = '2022.12'
Loading
Loading