Skip to content

Commit b3d7990

Browse files
committed
🚚 move HasArrayNamespace
Signed-off-by: nstarman <[email protected]>
1 parent 85ff8ac commit b3d7990

File tree

7 files changed

+107
-30
lines changed

7 files changed

+107
-30
lines changed

‎.github/workflows/ci.yml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,14 @@ jobs:
8888
python-version: "3.11"
8989
activate-environment: true
9090

91-
- name: get major numpy version
92-
id: numpy-major
91+
- name: get major.minor numpy version
92+
id: numpy-version
9393
run: |
94-
version=$(echo ${{ matrix.numpy-version }} | cut -c 1)
95-
echo "::set-output name=version::$version"
94+
version="${{ matrix.numpy-version }}"
95+
major=$(echo "$version" | cut -d. -f1)
96+
echo "major=$major" >> $GITHUB_OUTPUT
97+
minor=$(echo "$version" | cut -d. -f2)
98+
echo "minor=$minor" >> $GITHUB_OUTPUT
9699
97100
- name: install deps
98101
run: |
@@ -101,10 +104,21 @@ jobs:
101104
102105
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
103106
- name: mypy
104-
run: >
105-
uv run --no-sync --active
106-
mypy --tb --no-incremental --cache-dir=/dev/null
107-
tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi
107+
run: |
108+
major="${{ steps.numpy-version.outputs.major }}"
109+
minor="${{ steps.numpy-version.outputs.minor }}"
110+
111+
if [ "$major" -eq 1 ]; then
112+
file="test_numpy1.pyi"
113+
elif [ "$major" -eq 2 ] && [ "$minor" -lt 2 ]; then
114+
file="test_numpy2.0.pyi"
115+
else
116+
file="test_numpy2.2.pyi"
117+
fi
118+
119+
uv run --no-sync --active \
120+
mypy --tb --no-incremental --cache-dir=/dev/null \
121+
tests/integration/$file
108122
109123
# TODO: (based)pyright
110124

‎src/array_api_typing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
"__version_tuple__",
77
)
88

9-
from ._namespace import HasArrayNamespace
9+
from ._array import HasArrayNamespace
1010
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_namespace.py renamed to ‎src/array_api_typing/_array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from typing import Literal, Protocol
55
from typing_extensions import TypeVar
66

7-
T_co = TypeVar("T_co", covariant=True, default=ModuleType)
7+
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
88

99

10-
class HasArrayNamespace(Protocol[T_co]):
10+
class HasArrayNamespace(Protocol[NamespaceT_co]):
1111
"""Protocol for classes that have an `__array_namespace__` method.
1212
13+
This `Protocol` is intended for use in static typing to ensure that an
14+
object has an `__array_namespace__` method that returns a namespace for
15+
array operations. This `Protocol` should not be used at runtime, for type
16+
checking or as a base class.
17+
1318
Example:
1419
>>> import array_api_typing as xpt
1520
>>>
@@ -27,4 +32,4 @@ class HasArrayNamespace(Protocol[T_co]):
2732

2833
def __array_namespace__(
2934
self, /, *, api_version: Literal["2021.12"] | None = None
30-
) -> T_co: ...
35+
) -> NamespaceT_co: ...

‎tests/integration/test_numpy1.pyi

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,34 @@
1-
from typing import Any
1+
# mypy: disable-error-code="no-redef"
22

3-
# requires numpy < 2
4-
import numpy.array_api as np
3+
from types import ModuleType
4+
from typing import TypeAlias
5+
6+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
57

68
import array_api_typing as xpt
79

8-
###
9-
# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`.
10+
# DType aliases
11+
F32: TypeAlias = np.float32
12+
I32: TypeAlias = np.int32
13+
14+
# Define NDArrays against which we can test the protocols
15+
nparr = np.eye(2)
16+
nparr_i32 = np.array([1], dtype=I32)
17+
nparr_f32 = np.array([1.0], dtype=F32)
18+
nparr_b = np.array([True], dtype=np.bool_)
19+
20+
# =========================================================
21+
# `xpt.HasArrayNamespace`
22+
23+
_: xpt.HasArrayNamespace[ModuleType] = nparr
24+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
27+
28+
# Check `__array_namespace__` method
29+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
30+
ns: ModuleType = a_ns.__array_namespace__()
1031

11-
arr = np.eye(2)
12-
arr_namespace: xpt.HasArrayNamespace[Any] = arr
32+
# Incorrect values are caught when using `__array_namespace__` and
33+
# backpropagated to the type of `a_ns`
34+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
import array_api_typing as xpt
10+
11+
# DType aliases
12+
F32: TypeAlias = np.float32
13+
I32: TypeAlias = np.int32
14+
15+
# Define NDArrays against which we can test the protocols
16+
nparr: npt.NDArray[Any]
17+
nparr_i32: npt.NDArray[I32] = np.array([1], dtype=I32)
18+
nparr_f32: npt.NDArray[F32] = np.array([1.0], dtype=F32)
19+
nparr_b: npt.NDArray[np.bool_] = np.array([True], dtype=np.bool_)
20+
21+
# =========================================================
22+
# `xpt.HasArrayNamespace`
23+
24+
# Check assignment
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
27+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
28+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
29+
30+
# Check `__array_namespace__` method
31+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
32+
ns: ModuleType = a_ns.__array_namespace__()
33+
34+
# Incorrect values are caught when using `__array_namespace__` and
35+
# backpropagated to the type of `a_ns`
36+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from test_numpy2 import nparr
4+
5+
import array_api_typing as xpt
6+
7+
# Incorrect values are caught when using `__array_namespace__` and
8+
# backpropagated to the type of `a_ns`
9+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
10+
a_badns: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
11+
a_badns.__array_namespace__() # triggers error above

‎tests/integration/test_numpy2.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)