Skip to content

Commit 55a039a

Browse files
committed
address review
1 parent 7ae5766 commit 55a039a

File tree

3 files changed

+25
-18
lines changed

3 files changed

+25
-18
lines changed

src/array_api_extra/_funcs.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# https://github.com/pylint-dev/pylint/issues/10112
88
from collections.abc import Callable # pylint: disable=import-error
9-
from typing import ClassVar
9+
from typing import ClassVar, Literal
1010

1111
from ._lib import _utils
1212
from ._lib._compat import (
@@ -659,11 +659,11 @@ class at: # pylint: disable=invalid-name
659659
idx: Index
660660
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")
661661

662-
def __init__(self, x: Array, idx: Index = _undef, /):
662+
def __init__(self, x: Array, idx: Index = _undef, /) -> None:
663663
self.x = x
664664
self.idx = idx
665665

666-
def __getitem__(self, idx: Index) -> at:
666+
def __getitem__(self, idx: Index, /) -> at:
667667
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
668668
which looks prettier than ``at(x, slice(start, stop, step))``
669669
and feels more intuitive coming from the JAX documentation.
@@ -704,19 +704,16 @@ def _common(
704704

705705
x = self.x
706706

707-
if copy is True:
707+
if copy is None:
708+
writeable = is_writeable_array(x)
709+
copy = _is_update and not writeable
710+
elif copy:
708711
writeable = None
709-
elif copy is False:
712+
else:
710713
writeable = is_writeable_array(x)
711714
if not writeable:
712715
msg = "Cannot modify parameter in place"
713716
raise ValueError(msg)
714-
elif copy is None: # type: ignore[redundant-expr]
715-
writeable = is_writeable_array(x)
716-
copy = _is_update and not writeable
717-
else:
718-
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
719-
raise ValueError(msg)
720717

721718
if copy:
722719
try:
@@ -782,7 +779,9 @@ def set(self, y: Array, /, **kwargs: Untyped) -> Array:
782779

783780
def _iop(
784781
self,
785-
at_op: str,
782+
at_op: Literal[
783+
"set", "add", "subtract", "multiply", "divide", "power", "min", "max"
784+
],
786785
elwise_op: Callable[[Array, Array], Array],
787786
y: Array,
788787
/,

src/array_api_extra/_lib/_typing.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

33
import typing
4+
from collections.abc import Mapping
45
from types import ModuleType
5-
from typing import Any
6+
from typing import Any, Protocol
67

78
if typing.TYPE_CHECKING:
89
from typing_extensions import override
910

1011
# To be changed to a Protocol later (see data-apis/array-api#589)
11-
Array = Any # type: ignore[no-any-explicit]
12-
Device = Any # type: ignore[no-any-explicit]
13-
Index = Any # type: ignore[no-any-explicit]
1412
Untyped = Any # type: ignore[no-any-explicit]
13+
Array = Untyped
14+
Device = Untyped
15+
Index = Untyped
16+
17+
class CanAt(Protocol):
18+
@property
19+
def at(self) -> Mapping[Index, Untyped]: ...
20+
1521
else:
1622

1723
def no_op_decorator(f): # pyright: ignore[reportUnreachable]
1824
return f
1925

2026
override = no_op_decorator
2127

28+
CanAt = object
29+
2230
__all__ = ["ModuleType", "override"]
2331
if typing.TYPE_CHECKING:
2432
__all__ += ["Array", "Device", "Index", "Untyped"]

tests/test_at.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from contextlib import contextmanager, suppress
44
from importlib import import_module
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Final
66

77
import numpy as np
88
import pytest
@@ -18,7 +18,7 @@
1818
if TYPE_CHECKING:
1919
from array_api_extra._lib._typing import Array, Untyped
2020

21-
all_libraries = (
21+
all_libraries: Final = (
2222
"array_api_strict",
2323
"numpy",
2424
"numpy_readonly",

0 commit comments

Comments
 (0)