Skip to content

Commit 9b98cc1

Browse files
committed
✨ HasDType
Signed-off-by: nstarman <[email protected]>
1 parent a502081 commit 9b98cc1

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ ignore = [
124124
"FBT", # flake8-boolean-trap
125125
"FIX", # flake8-fixme
126126
"ISC001", # Conflicts with formatter
127+
"PYI041", # Use `float` instead of `int | float`
128+
"TD002", # Missing author in TODO
129+
"TD003", # Missing issue link for this TODO
127130
]
128131

129132
[tool.ruff.lint.pylint]

src/array_api_typing/_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -38,8 +39,32 @@ def __array_namespace__(
3839
) -> NamespaceT_co: ...
3940

4041

42+
class HasDType(Protocol[DTypeT_co]):
43+
"""Protocol for array classes that have a data type attribute."""
44+
45+
@property
46+
def dtype(self) -> DTypeT_co:
47+
"""Data type of the array elements."""
48+
...
49+
50+
4151
class Array(
42-
HasArrayNamespace[NamespaceT_co],
43-
Protocol[NamespaceT_co],
52+
# ------ Attributes -------
53+
HasDType[DTypeT_co],
54+
# -------------------------
55+
Protocol[DTypeT_co, NamespaceT_co],
4456
):
45-
"""Array API specification for array object attributes and methods."""
57+
"""Array API specification for array object attributes and methods.
58+
59+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
60+
NamespaceT]`` where:
61+
62+
- `DTypeT` is the data type of the array elements.
63+
- `NamespaceT` is the type of the array namespace. It defaults to
64+
`ModuleType`, which is the most common form of array namespace (e.g.,
65+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
66+
`types.SimpleNamespace`, to allow for wrapper libraries to
67+
semi-dynamically define their own array namespaces based on the wrapped
68+
array type.
69+
70+
"""

tests/integration/test_numpy1.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any, TypeAlias
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
77

@@ -38,3 +38,9 @@ _: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3838

3939
# Check NamespaceT_co assignment
4040
a_ns: xpt.Array[ModuleType] = nparr
41+
42+
# Check DTypeT_co assignment
43+
_: xpt.Array[Any] = nparr
44+
_: xpt.Array[np.dtype[I32]] = nparr_i32
45+
_: xpt.Array[np.dtype[F32]] = nparr_f32
46+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

tests/integration/test_numpy2.0.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import numpy.typing as npt
99
import array_api_typing as xpt
1010

1111
# DType aliases
12+
F: TypeAlias = np.floating[Any]
1213
F32: TypeAlias = np.float32
14+
I: TypeAlias = np.integer[Any]
1315
I32: TypeAlias = np.int32
1416

1517
# Define NDArrays against which we can test the protocols
@@ -40,3 +42,9 @@ _: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
4042

4143
# Check NamespaceT_co assignment
4244
a_ns: xpt.Array[ModuleType] = nparr
45+
46+
# Check DTypeT_co assignment
47+
_: xpt.Array[Any] = nparr
48+
_: xpt.Array[np.dtype[I32]] = nparr_i32
49+
_: xpt.Array[np.dtype[F32]] = nparr_f32
50+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
 (0)