Skip to content

Commit 1b2ea91

Browse files
committed
✨ add CanArray* unop and binop protocols
Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent 858a3db commit 1b2ea91

File tree

8 files changed

+364
-14
lines changed

8 files changed

+364
-14
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"typing-extensions>=4.14.1",
3030
"optype>=0.9.3; python_version < '3.11'",
3131
"optype>=0.12.2; python_version >= '3.11'",
32+
"tomli>=1.2.0 ; python_full_version < '3.11'",
3233
]
3334

3435
[project.urls]
@@ -123,9 +124,12 @@ ignore = [
123124
"D107", # Missing docstring in __init__
124125
"D203", # 1 blank line required before class docstring
125126
"D213", # Multi-line docstring summary should start at the second line
127+
"D401", # First line of docstring should be in imperative mood
126128
"FBT", # flake8-boolean-trap
127129
"FIX", # flake8-fixme
128130
"ISC001", # Conflicts with formatter
131+
"PLW1641", # Object does not implement `__hash__` method
132+
"PYI041", # Use `float` instead of `int | float`
129133
]
130134

131135
[tool.ruff.lint.pylint]

src/array_api_typing/_array.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
__all__ = (
22
"Array",
3+
"BoolArray",
34
"HasArrayNamespace",
5+
"NumericArray",
46
)
57

8+
from pathlib import Path
69
from types import ModuleType
7-
from typing import Literal, Protocol
10+
from typing import Literal, Never, Protocol, TypeAlias
811
from typing_extensions import TypeVar
912

13+
import optype as op
14+
15+
from ._utils import docstring_setter
16+
17+
# Load docstrings from TOML file
18+
try:
19+
import tomllib
20+
except ImportError:
21+
import tomli as tomllib # type: ignore[import-not-found, no-redef]
22+
23+
_docstrings_path = Path(__file__).parent / "_array_docstrings.toml"
24+
with _docstrings_path.open("rb") as f:
25+
_array_docstrings = tomllib.load(f)["docstrings"]
26+
1027
NS_co = TypeVar("NS_co", covariant=True, default=ModuleType)
28+
T_contra = TypeVar("T_contra", contravariant=True)
29+
R_co = TypeVar("R_co", covariant=True, default=Never)
1130

1231

1332
class HasArrayNamespace(Protocol[NS_co]):
@@ -33,8 +52,37 @@ def __array_namespace__(
3352
) -> NS_co: ...
3453

3554

55+
@docstring_setter(**_array_docstrings)
3656
class Array(
3757
HasArrayNamespace[NS_co],
38-
Protocol[NS_co],
58+
op.CanPosSelf,
59+
op.CanNegSelf,
60+
op.CanAddSame[T_contra, R_co],
61+
op.CanSubSame[T_contra, R_co],
62+
op.CanMulSame[T_contra, R_co],
63+
op.CanTruedivSame[T_contra, R_co],
64+
op.CanFloordivSame[T_contra, R_co],
65+
op.CanModSame[T_contra, R_co],
66+
op.CanPowSame[T_contra, R_co],
67+
Protocol[T_contra, R_co, NS_co],
3968
):
4069
"""Array API specification for array object attributes and methods."""
70+
71+
72+
BoolArray: TypeAlias = Array[bool, Array[float, Never, NS_co], NS_co]
73+
"""Array API specification for boolean array object attributes and methods.
74+
75+
Specifically, this type alias fills the `T_contra` type variable with
76+
`bool`, allowing for `bool` objects to be added, subtracted, multiplied, etc. to
77+
the array object.
78+
79+
"""
80+
81+
NumericArray: TypeAlias = Array[float | int, NS_co]
82+
"""Array API specification for numeric array object attributes and methods.
83+
84+
Specifically, this type alias fills the `T_contra` type variable with `float
85+
| int`, allowing for `float | int` objects to be added, subtracted, multiplied,
86+
etc. to the array object.
87+
88+
"""
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
[docstrings]
2+
__pos__ = '''
3+
Evaluates `+self_i` for each element of an array instance.
4+
5+
Returns:
6+
Self: An array containing the evaluated result for each element.
7+
The returned array must have the same data type as `self`.
8+
9+
See Also:
10+
array_api_typing.Positive
11+
12+
'''
13+
14+
__neg__ = '''
15+
Evaluates `-self_i` for each element of an array instance.
16+
17+
Returns:
18+
Self: an array containing the evaluated result for each element in
19+
`self`. The returned array must have a data type determined by Type
20+
Promotion Rules.
21+
22+
See Also:
23+
array_api_typing.Negative
24+
25+
'''
26+
27+
__add__ = '''
28+
Calculates the sum for each element of an array instance with the respective
29+
element of the array `other`.
30+
31+
Args:
32+
other: addend array. Must be compatible with `self` (see
33+
Broadcasting). Should have a numeric data type.
34+
35+
Returns:
36+
Self: an array containing the element-wise sums. The returned array
37+
must have a data type determined by Type Promotion Rules.
38+
39+
See Also:
40+
array_api_typing.Add
41+
42+
'''
43+
44+
__sub__ = '''
45+
Calculates the difference for each element of an array instance with the
46+
respective element of the array other.
47+
48+
The result of `self_i - other_i` must be the same as `self_i +
49+
(-other_i)` and must be governed by the same floating-point rules as
50+
addition (see `CanArrayAdd`).
51+
52+
Args:
53+
other: subtrahend array. Must be compatible with `self` (see
54+
Broadcasting). Should have a numeric data type.
55+
56+
Returns:
57+
Self: an array containing the element-wise differences. The returned
58+
array must have a data type determined by Type Promotion Rules.
59+
60+
See Also:
61+
array_api_typing.Subtract
62+
63+
'''
64+
65+
__mul__ = '''
66+
Calculates the product for each element of an array instance with the
67+
respective element of the array `other`.
68+
69+
Args:
70+
other: multiplicand array. Must be compatible with `self` (see
71+
Broadcasting). Should have a numeric data type.
72+
73+
Returns:
74+
Self: an array containing the element-wise products. The returned
75+
array must have a data type determined by Type Promotion Rules.
76+
77+
See Also:
78+
array_api_typing.Multiply
79+
80+
'''
81+
82+
__truediv__ = '''
83+
Evaluates `self_i / other_i` for each element of an array instance with the
84+
respective element of the array `other`.
85+
86+
Args:
87+
other: Must be compatible with `self` (see Broadcasting). Should have a
88+
numeric data type.
89+
90+
Returns:
91+
Self: an array containing the element-wise results. The returned array
92+
should have a floating-point data type determined by Type Promotion
93+
Rules.
94+
95+
See Also:
96+
array_api_typing.TrueDiv
97+
98+
'''
99+
100+
__floordiv__ = '''
101+
Evaluates `self_i // other_i` for each element of an array instance with the
102+
respective element of the array `other`.
103+
104+
Args:
105+
other: Must be compatible with `self` (see Broadcasting). Should have a
106+
numeric data type.
107+
108+
Returns:
109+
Self: an array containing the element-wise results. The returned array
110+
must have a data type determined by Type Promotion Rules.
111+
112+
See Also:
113+
array_api_typing.FloorDiv
114+
115+
'''
116+
117+
__mod__ = '''
118+
Evaluates `self_i % other_i` for each element of an array instance with the
119+
respective element of the array `other`.
120+
121+
Args:
122+
other: Must be compatible with `self` (see Broadcasting). Should have a
123+
numeric data type.
124+
125+
Returns:
126+
Self: an array containing the element-wise results. Each element-wise
127+
result must have the same sign as the respective element `other_i`.
128+
The returned array must have a floating-point data type determined
129+
by Type Promotion Rules.
130+
131+
See Also:
132+
array_api_typing.Remainder
133+
134+
'''
135+
136+
__pow__ = '''
137+
Calculates an implementation-dependent approximation of exponentiation by
138+
raising each element (the base) of an array instance to the power of
139+
`other_i` (the exponent), where `other_i` is the corresponding element of
140+
the array `other`.
141+
142+
Args:
143+
other: array whose elements correspond to the exponentiation exponent.
144+
Must be compatible with `self` (see Broadcasting). Should have a
145+
numeric data type.
146+
147+
Returns:
148+
Self: an array containing the element-wise results. The returned array
149+
must have a data type determined by Type Promotion Rules.
150+
151+
'''

src/array_api_typing/_namespace.py

Whitespace-only changes.

src/array_api_typing/_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Utility functions."""
2+
3+
from collections.abc import Callable
4+
from enum import Enum, auto
5+
from typing import Literal, TypeVar
6+
7+
ClassT = TypeVar("ClassT")
8+
DocstringTypes = str | None
9+
10+
11+
class _Sentinel(Enum):
12+
SKIP = auto()
13+
14+
15+
def set_docstrings(
16+
obj: type[ClassT],
17+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
18+
/,
19+
**method_docs: DocstringTypes,
20+
) -> type[ClassT]:
21+
"""Set the docstring for a class and its methods.
22+
23+
Args:
24+
obj: The class to set the docstring for.
25+
main: The main docstring for the class. If not provided, the
26+
class docstring will not be modified.
27+
method_docs: A mapping of method names to their docstrings. If a method
28+
is not provided, its docstring will not be modified.
29+
30+
Returns:
31+
The class with updated docstrings.
32+
33+
"""
34+
if main is not _Sentinel.SKIP:
35+
obj.__doc__ = main
36+
37+
for name, doc in method_docs.items():
38+
method = getattr(obj, name)
39+
method.__doc__ = doc
40+
return obj
41+
42+
43+
def docstring_setter(
44+
main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP,
45+
/,
46+
**method_docs: DocstringTypes,
47+
) -> Callable[[type[ClassT]], type[ClassT]]:
48+
"""Decorator to set docstrings for a class and its methods.
49+
50+
Args:
51+
main: The main docstring for the class. If not provided, the
52+
class docstring will not be modified.
53+
method_docs: A mapping of method names to their docstrings. If a method
54+
is not provided, its docstring will not be modified.
55+
56+
Returns:
57+
A decorator that sets the docstrings for the class and its methods.
58+
59+
"""
60+
61+
def decorator(cls: type[ClassT]) -> type[ClassT]:
62+
return set_docstrings(cls, main, **method_docs)
63+
64+
return decorator

tests/integration/test_numpy1.pyi

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,53 @@
1-
from typing import Any
1+
# Test array_api_typing with numpy < 2.0
22

3-
# requires numpy < 2
4-
import numpy.array_api as np # type: ignore[import-not-found]
3+
from typing import Any, Never, TypeAlias
4+
5+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
6+
from numpy import dtype, floating, integer
57

68
import array_api_typing as xpt
9+
from array_api_typing._array import BoolArray, NumericArray
10+
11+
F: TypeAlias = floating[Any]
12+
F32: TypeAlias = dtype # Note: np.array_api uses dtype objects.
13+
I: TypeAlias = integer[Any]
14+
I32: TypeAlias = dtype # Note: np.array_api uses dtype objects.
15+
16+
# Define an NDArray against which we can test the protocols
17+
nparr = np.eye(2)
18+
nparr_i32 = np.asarray([1], dtype=np.int32)
19+
nparr_f32 = np.asarray([1.0], dtype=np.float32)
20+
nparr_b = np.asarray([True], dtype=bool)
21+
22+
# =========================================================
23+
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`
24+
25+
arr_ns: xpt.HasArrayNamespace[Any] = nparr
26+
arr_ns_i32: xpt.HasArrayNamespace[Any] = nparr_i32
27+
arr_ns_f32: xpt.HasArrayNamespace[Any] = nparr_f32
28+
29+
# =========================================================
30+
# Ensure that `np.ndarray` instances are assignable to `xpt.Array`.
31+
32+
# Generic Array type
33+
arr_array: xpt.Array[Never] = nparr
34+
35+
# Float Array types
36+
arr_float: xpt.Array[float] = nparr_f32
37+
arr_f: xpt.Array[F] = nparr_f32
38+
arr_f32: xpt.Array[F32] = nparr_f32
39+
40+
# Integer Array types
41+
arr_int: xpt.Array[int, xpt.Array[float | int]] = nparr_i32
42+
arr_i: xpt.Array[I, xpt.Array[float | int]] = nparr_i32
43+
arr_i32: xpt.Array[I32, xpt.Array[F32 | I32]] = nparr_i32
44+
45+
# Boolean Array types
46+
arr_bool: xpt.Array[bool, xpt.Array[float | int | bool]] = nparr_b
47+
arr_b: xpt.Array[np.bool, xpt.Array[F | I | np.bool]] = nparr_b
748

8-
###
9-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
49+
# =========================================================
50+
# Check np.ndarray against BoolArray and NumericArray type aliases
1051

11-
arr = np.eye(2)
12-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
52+
boolarray: BoolArray = nparr_b
53+
numericarray: NumericArray = nparr_f32

0 commit comments

Comments
 (0)