Skip to content

POC: appease linter for gh-53 #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/test-vendor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ jobs:
- name: Checkout array-api-compat
uses: actions/checkout@v4
with:
repository: data-apis/array-api-compat
# DNM
# repository: data-apis/array-api-compat
repository: crusaderky/array-api-compat
ref: d7ab986843cc9eb20882d7ccbf7248d78fcbd759
# /DNM
path: array-api-compat

- name: Vendor array-api-extra into test package
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated

at
atleast_nd
cov
create_diagonal
Expand Down
105 changes: 59 additions & 46 deletions pixi.lock

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.1.1"]
# DNM
# dependencies = ["array-api-compat>=1.1.1"]
dependencies = []

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -63,9 +65,12 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10.15,<3.14"
array-api-compat = ">=1.1.1"
# array-api-compat = ">=1.1.1" # DNM

[tool.pixi.pypi-dependencies]
# DNM main plus #205, #207, #211
array-api-compat = { git = "https://github.com/crusaderky/array-api-compat.git", rev = "d7ab986843cc9eb20882d7ccbf7248d78fcbd759" }

array-api-extra = { path = ".", editable = true }

[tool.pixi.feature.lint.dependencies]
Expand Down Expand Up @@ -190,6 +195,8 @@ reportAny = false
reportExplicitAny = false
# data-apis/array-api-strict#6
reportUnknownMemberType = false
# no array-api-compat type stubs
reportUnknownVariableType = false


# Ruff
Expand Down
12 changes: 11 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
setdiff1d,
sinc,
)

__version__ = "0.3.3.dev0"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down
292 changes: 289 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import operator
import typing
import warnings

if typing.TYPE_CHECKING:
from ._lib._typing import Array, ModuleType
# https://github.com/pylint-dev/pylint/issues/10112
from collections.abc import Callable # pylint: disable=import-error
from typing import ClassVar, Literal

from ._lib import _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace,
is_array_api_obj,
is_dask_array,
is_writeable_array,
)

if typing.TYPE_CHECKING:
from ._lib._typing import Array, Index, ModuleType, Untyped

__all__ = [
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down Expand Up @@ -548,3 +559,278 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
)
return xp.sin(y) / y


_undef = object()


class at: # pylint: disable=invalid-name
"""
Update operations for read-only arrays.

This implements ``jax.numpy.ndarray.at`` for all backends.

Parameters
----------
x : array
Input array.
idx : index, optional
You may use two alternate syntaxes::

at(x, idx).set(value) # or get(), add(), etc.
at(x)[idx].set(value)

copy : bool, optional
True (default)
Ensure that the inputs are not modified.
False
Ensure that the update operation writes back to the input.
Raise ValueError if a copy cannot be avoided.
None
The array parameter *may* be modified in place if it is possible and
beneficial for performance.
You should not reuse it after calling this function.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer

**kwargs:
If the backend supports an `at` method, any additional keyword
arguments are passed to it verbatim; e.g. this allows passing
``indices_are_sorted=True`` to JAX.

Returns
-------
Updated input array.

Examples
--------
Given either of these equivalent expressions::

x = at(x)[1].add(2, copy=None)
x = at(x, 1).add(2, copy=None)

If x is a JAX array, they are the same as::

x = x.at[1].add(2)

If x is a read-only numpy array, they are the same as::

x = x.copy()
x[1] += 2

Otherwise, they are the same as::

x[1] += 2

Warning
-------
When you use copy=None, you should always immediately overwrite
the parameter array::

x = at(x, 0).set(2, copy=None)

The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays::

x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2, copy=None)
z = at(x, 1).set(3, copy=None)

In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!

Warning
-------
The array API standard does not support integer array indices.
The behaviour of update methods when the index is an array of integers
is undefined; this is particularly true when the index contains multiple
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.

Note
----
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.

See Also
--------
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
"""

x: Array
idx: Index
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO the linter should not force me to define the type of __slots__, because it's part of the python data model. This only adds attrition and reduces readability.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

Copy link
Member Author

@lucascolley lucascolley Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this is no longer needed

EDIT: only if at is made final


def __init__(self, x: Array, idx: Index = _undef, /) -> None:
self.x = x
self.idx = idx

def __getitem__(self, idx: Index, /) -> at:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
msg = "Index has already been set"
raise ValueError(msg)
self.idx = idx
return self

def _common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Untyped,
) -> tuple[Untyped, None] | tuple[None, Array]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it perhaps possible to @overload these cases?

Copy link
Contributor

@crusaderky crusaderky Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't because the return type depends on a duck-type test on x. And I'm definitely unwilling to explore writing a class HasAtMethod(Protocol) for a small internal function that is consumed exclusively 2 paragraph below.

Copy link

@jorenham jorenham Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm definitely unwilling to explore writing a class HasAtMethod(Protocol) for a small internal function that is consumed exclusively 2 paragraph below.

I am willing, so here you go:

class _CanAt(Protocol):
    @property
    def at(self) -> Mapping[Index, Untyped] ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making this change - could you clarify how to use _CanAt @jorenham ?

"""Perform common prepocessing.

Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
msg = (
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)
raise TypeError(msg)

x = self.x

if copy is None:
writeable = is_writeable_array(x)
copy = _is_update and not writeable
elif copy:
writeable = None
else:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)

if copy:
try:
at_ = x.at
except AttributeError:
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
# Create writeable copy of read-only numpy array
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
writeable = None
else:
# Use JAX's at[] or other library that with the same duck-type API
args = (y,) if y is not _undef else ()
return getattr(at_[self.idx], at_op)(*args, **kwargs), None

if _is_update:
if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
msg = f"Array {x} has no `at` method and is read-only"
raise ValueError(msg)

return None, x

def get(self, **kwargs: Untyped) -> Untyped:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
if kwargs.get("copy") is False:
if is_array_api_obj(self.idx):
# Boolean index. Note that the array API spec
# https://data-apis.org/array-api/latest/API_specification/indexing.html
# does not allow for list, tuple, and tuples of slices plus one or more
# one-dimensional array indices, although many backends support them.
# So this check will encounter a lot of false negatives in real life,
# which can be caught by testing the user code vs. array-api-strict.
msg = "get() with an array index always returns a copy"
raise ValueError(msg)
if is_dask_array(self.x):
msg = "get() on Dask arrays always returns a copy"
raise ValueError(msg)

res, x = self._common("get", _is_update=False, **kwargs)
if res is not None:
return res
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = y
return x

def _iop(
self,
at_op: Literal[
"set", "add", "subtract", "multiply", "divide", "power", "min", "max"
],
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
**kwargs: Untyped,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x

which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
if res is not None:
return res
assert x is not None
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)

def max(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, **kwargs)
Loading
Loading