Skip to content

Commit a502081

Browse files
committed
✨Array class
Signed-off-by: nstarman <[email protected]>
1 parent b3d7990 commit a502081

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
45
"HasArrayNamespace",
56
"__version__",
67
"__version_tuple__",
78
)
89

9-
from ._array import HasArrayNamespace
10+
from ._array import Array, HasArrayNamespace
1011
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__all__ = ("HasArrayNamespace",)
1+
__all__ = (
2+
"Array",
3+
"HasArrayNamespace",
4+
)
25

36
from types import ModuleType
47
from typing import Literal, Protocol
@@ -33,3 +36,10 @@ class HasArrayNamespace(Protocol[NamespaceT_co]):
3336
def __array_namespace__(
3437
self, /, *, api_version: Literal["2021.12"] | None = None
3538
) -> NamespaceT_co: ...
39+
40+
41+
class Array(
42+
HasArrayNamespace[NamespaceT_co],
43+
Protocol[NamespaceT_co],
44+
):
45+
"""Array API specification for array object attributes and methods."""

tests/integration/test_numpy1.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,9 @@ ns: ModuleType = a_ns.__array_namespace__()
3232
# Incorrect values are caught when using `__array_namespace__` and
3333
# backpropagated to the type of `a_ns`
3434
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
35+
36+
# =========================================================
37+
# `xpt.Array`
38+
39+
# Check NamespaceT_co assignment
40+
a_ns: xpt.Array[ModuleType] = nparr

tests/integration/test_numpy2.0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,9 @@ ns: ModuleType = a_ns.__array_namespace__()
3434
# Incorrect values are caught when using `__array_namespace__` and
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
38+
# =========================================================
39+
# `xpt.Array`
40+
41+
# Check NamespaceT_co assignment
42+
a_ns: xpt.Array[ModuleType] = nparr

0 commit comments

Comments
 (0)