Skip to content

Commit e1e7644

Browse files
committed
✨HasDType
Signed-off-by: nstarman <[email protected]>
1 parent eaa42ce commit e1e7644

File tree

4 files changed

+66
-7
lines changed

4 files changed

+66
-7
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = (
44
"Array",
55
"HasArrayNamespace",
6+
"HasDType",
67
"__version__",
78
"__version_tuple__",
89
)
910

10-
from ._array import Array, HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1112
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -57,8 +58,32 @@ def __array_namespace__(
5758
...
5859

5960

61+
class HasDType(Protocol[DTypeT_co]):
62+
"""Protocol for array classes that have a data type attribute."""
63+
64+
@property
65+
def dtype(self, /) -> DTypeT_co:
66+
"""Data type of the array elements."""
67+
...
68+
69+
6070
class Array(
61-
HasArrayNamespace[NamespaceT_co],
62-
Protocol[NamespaceT_co],
71+
# ------ Attributes -------
72+
HasDType[DTypeT_co],
73+
# -------------------------
74+
Protocol[DTypeT_co, NamespaceT_co],
6375
):
64-
"""Array API specification for array object attributes and methods."""
76+
"""Array API specification for array object attributes and methods.
77+
78+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
79+
NamespaceT]`` where:
80+
81+
- `DTypeT` is the data type of the array elements.
82+
- `NamespaceT` is the type of the array namespace. It defaults to
83+
`ModuleType`, which is the most common form of array namespace (e.g.,
84+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
85+
`types.SimpleNamespace`, to allow for wrapper libraries to
86+
semi-dynamically define their own array namespaces based on the wrapped
87+
array type.
88+
89+
"""

tests/integration/test_numpy1p0.pyi

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
from numpy import dtype
78

89
import array_api_typing as xpt
910

@@ -28,8 +29,25 @@ ns: ModuleType = a_ns.__array_namespace__()
2829
# backpropagated to the type of `a_ns`
2930
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3031

32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
36+
# type annotate specific dtypes like `np.float32` or `np.int32`.
37+
38+
_: xpt.HasDType[dtype[Any]] = nparr
39+
_: xpt.HasDType[dtype[Any]] = nparr_i32
40+
_: xpt.HasDType[dtype[Any]] = nparr_f32
41+
3142
# =========================================================
3243
# `xpt.Array`
3344

3445
# Check NamespaceT_co assignment
35-
a_ns: xpt.Array[ModuleType] = nparr
46+
a_ns: xpt.Array[Any, ModuleType] = nparr
47+
48+
# Check DTypeT_co assignment
49+
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
50+
# type annotate specific dtypes like `np.float32` or `np.int32`.
51+
_: xpt.Array[dtype[Any]] = nparr
52+
_: xpt.Array[dtype[Any]] = nparr_i32
53+
_: xpt.Array[dtype[Any]] = nparr_f32

tests/integration/test_numpy2p0.pyi

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,23 @@ ns: ModuleType = a_ns.__array_namespace__()
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3737

38+
# =========================================================
39+
# `xpt.HasDType`
40+
41+
# Check DTypeT_co assignment
42+
_: xpt.HasDType[Any] = nparr
43+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
44+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
45+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
46+
3847
# =========================================================
3948
# `xpt.Array`
4049

4150
# Check NamespaceT_co assignment
42-
a_ns: xpt.Array[ModuleType] = nparr
51+
a_ns: xpt.Array[Any, ModuleType] = nparr
52+
53+
# Check DTypeT_co assignment
54+
_: xpt.Array[Any] = nparr
55+
_: xpt.Array[np.dtype[I32]] = nparr_i32
56+
_: xpt.Array[np.dtype[F32]] = nparr_f32
57+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
 (0)