Skip to content

Commit 6ad0d17

Browse files
committed
code review
1 parent ddf6260 commit 6ad0d17

File tree

9 files changed

+44
-45
lines changed

9 files changed

+44
-45
lines changed

array_api_strict/_array_object.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
from __future__ import annotations
1717

1818
import operator
19+
import sys
1920
from collections.abc import Iterator
2021
from enum import IntEnum
2122
from types import ModuleType
22-
from typing import Any, Literal, SupportsIndex
23+
from typing import Any, Final, Literal, SupportsIndex
2324

2425
import numpy as np
2526
import numpy.typing as npt
2627

27-
from ._creation_functions import _undef, Undef, asarray
28+
from ._creation_functions import Undef, _undef, asarray
2829
from ._dtypes import (
2930
DType,
3031
_all_dtypes,
@@ -42,14 +43,16 @@
4243
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
4344
from ._typing import PyCapsule
4445

45-
try:
46-
from types import EllipsisType # Python >=3.10
47-
except ImportError:
46+
if sys.version_info >= (3, 10):
47+
from types import EllipsisType
48+
elif TYPE_CHECKING:
49+
from typing_extensions import EllipsisType
50+
else:
4851
EllipsisType = type(Ellipsis)
4952

5053

5154
class Device:
52-
_device: str
55+
_device: Final[str]
5356
__slots__ = ("_device", "__weakref__")
5457

5558
def __init__(self, device: str = "CPU_DEVICE"):
@@ -101,7 +104,7 @@ class Array:
101104
# Use a custom constructor instead of __init__, as manually initializing
102105
# this class is not supported API.
103106
@classmethod
104-
def _new(cls, x: npt.NDArray[Any], /, device: Device | None) -> Array:
107+
def _new(cls, x: npt.NDArray[Any] | np.generic, /, device: Device | None) -> Array:
105108
"""
106109
This is a private method for initializing the array API Array
107110
object.

array_api_strict/_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
inf = np.inf
55
nan = np.nan
66
pi = np.pi
7-
newaxis = np.newaxis
7+
newaxis: None = np.newaxis

array_api_strict/_creation_functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack
1313

1414
if TYPE_CHECKING:
15+
# TODO import from typing (requires Python >=3.13)
16+
from typing_extensions import TypeIs
17+
1518
# Circular import
1619
from ._array_object import Array, Device
1720

18-
1921
class Undef(Enum):
2022
UNDEF = 0
2123

@@ -44,7 +46,7 @@ def _check_valid_dtype(dtype: DType | None) -> None:
4446
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
4547

4648

47-
def _supports_buffer_protocol(obj: object) -> bool:
49+
def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
4850
try:
4951
memoryview(obj) # type: ignore[arg-type]
5052
except TypeError:

array_api_strict/_data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from ._array_object import Array, Device
8-
from ._creation_functions import _check_device, _undef, Undef
8+
from ._creation_functions import Undef, _check_device, _undef
99
from ._dtypes import (
1010
DType,
1111
_all_dtypes,

array_api_strict/_dtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import builtins
44
import warnings
5-
from typing import Any
5+
from typing import Any, Final
66

77
import numpy as np
88
import numpy.typing as npt
@@ -12,7 +12,7 @@
1212

1313

1414
class DType:
15-
_np_dtype: np.dtype[Any]
15+
_np_dtype: Final[np.dtype[Any]]
1616
__slots__ = ("_np_dtype", "__weakref__")
1717

1818
def __init__(self, np_dtype: npt.DTypeLike):

array_api_strict/_flags.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import warnings
2020
from collections.abc import Callable
2121
from types import TracebackType
22-
from typing import TYPE_CHECKING, Any, Collection, TypeVar
22+
from typing import TYPE_CHECKING, Any, Collection, TypeVar, cast
2323

2424
import array_api_strict
2525

@@ -30,6 +30,7 @@
3030
P = ParamSpec("P")
3131

3232
T = TypeVar("T")
33+
_CallableT = TypeVar("_CallableT", bound=Callable[..., object])
3334

3435

3536
supported_versions = (
@@ -389,8 +390,7 @@ def set_flags_from_environment() -> None:
389390

390391
# Decorators
391392

392-
393-
def requires_api_version(version: str) -> Callable[[Callable], Callable]:
393+
def requires_api_version(version: str) -> Callable[[_CallableT], _CallableT]:
394394
def decorator(func: Callable[P, T]) -> Callable[P, T]:
395395
@functools.wraps(func)
396396
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
@@ -403,7 +403,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
403403

404404
return wrapper
405405

406-
return decorator
406+
return cast(Callable[[_CallableT], _CallableT], decorator)
407407

408408

409409
def requires_data_dependent_shapes(func: Callable[P, T]) -> Callable[P, T]:
@@ -415,7 +415,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
415415
return wrapper
416416

417417

418-
def requires_extension(extension: str) -> Callable[[Callable], Callable]:
418+
def requires_extension(extension: str) -> Callable[[_CallableT], _CallableT]:
419419
def decorator(func: Callable[P, T]) -> Callable[P, T]:
420420
@functools.wraps(func)
421421
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
@@ -426,5 +426,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
426426
raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.")
427427
raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict")
428428
return func(*args, **kwargs)
429+
429430
return wrapper
430-
return decorator
431+
432+
return cast(Callable[[_CallableT], _CallableT], decorator)

array_api_strict/_linalg.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,7 @@
99

1010
from ._array_object import Array
1111
from ._data_type_functions import finfo
12-
from ._dtypes import (
13-
DType,
14-
_floating_dtypes,
15-
_numeric_dtypes,
16-
complex64,
17-
complex128,
18-
)
12+
from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128
1913
from ._elementwise_functions import conj
2014
from ._flags import get_array_api_strict_flags, requires_extension
2115
from ._manipulation_functions import reshape

array_api_strict/_typing.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,18 @@ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
5252
},
5353
)
5454

55-
DataTypes = TypedDict(
56-
"DataTypes",
57-
{
58-
"bool": DType,
59-
"float32": DType,
60-
"float64": DType,
61-
"complex64": DType,
62-
"complex128": DType,
63-
"int8": DType,
64-
"int16": DType,
65-
"int32": DType,
66-
"int64": DType,
67-
"uint8": DType,
68-
"uint16": DType,
69-
"uint32": DType,
70-
"uint64": DType,
71-
},
72-
total=False,
73-
)
55+
56+
class DataTypes(TypedDict, total=False):
57+
bool: DType
58+
float32: DType
59+
float64: DType
60+
complex64: DType
61+
complex128: DType
62+
int8: DType
63+
int16: DType
64+
int32: DType
65+
int64: DType
66+
uint8: DType
67+
uint16: DType
68+
uint32: DType
69+
uint64: DType

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ show_error_codes = true
4040
warn_redundant_casts = true
4141
warn_unused_ignores = true
4242
warn_unreachable = true
43+
strict_bytes = true
44+
local_partial_types = true
4345

4446
[[tool.mypy.overrides]]
4547
module = ["*.tests.*"]

0 commit comments

Comments
 (0)