Skip to content

Commit 6cd171b

Browse files
crusaderkyjorenham
andauthored
Apply suggestions from code review
Co-authored-by: Joren Hammudoglu <[email protected]>
1 parent e2c6d40 commit 6cd171b

File tree

4 files changed

+6
-5
lines changed

4 files changed

+6
-5
lines changed

array_api_strict/_array_object.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ def __rand__(self, other: Array | int, /) -> Array:
10231023
res = self._array.__rand__(other._array)
10241024
return self.__class__._new(res, device=self.device)
10251025

1026-
def __ifloordiv__(self, other: Array | complex, /) -> Array:
1026+
def __ifloordiv__(self, other: Array | float, /) -> Array:
10271027
"""
10281028
Performs the operation __ifloordiv__.
10291029
"""
@@ -1034,7 +1034,7 @@ def __ifloordiv__(self, other: Array | complex, /) -> Array:
10341034
self._array.__ifloordiv__(other._array)
10351035
return self
10361036

1037-
def __rfloordiv__(self, other: Array | complex, /) -> Array:
1037+
def __rfloordiv__(self, other: Array | float, /) -> Array:
10381038
"""
10391039
Performs the operation __rfloordiv__.
10401040
"""
@@ -1105,7 +1105,7 @@ def __imod__(self, other: Array | complex, /) -> Array:
11051105
self._array.__imod__(other._array)
11061106
return self
11071107

1108-
def __rmod__(self, other: Array | complex, /) -> Array:
1108+
def __rmod__(self, other: Array | float, /) -> Array:
11091109
"""
11101110
Performs the operation __rmod__.
11111111
"""

array_api_strict/_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Undef(Enum):
2424

2525

2626
@contextmanager
27-
def allow_array() -> Generator[None, None, None]:
27+
def allow_array() -> Generator[None]:
2828
"""
2929
Temporarily enable Array.__array__. This is needed for np.array to parse
3030
list of lists of Array objects.

array_api_strict/_flags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def __exit__(
348348
exc_type: type[BaseException] | None,
349349
exc_value: BaseException | None,
350350
traceback: TracebackType | None,
351+
/,
351352
) -> None:
352353
set_array_api_strict_flags(**self.old_flags)
353354

array_api_strict/_utility_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def diff(
5656
# currently specified.
5757

5858
# NumPy does not support prepend=None or append=None
59-
kwargs: dict[str, Any] = dict(axis=axis, n=n)
59+
kwargs: dict[str, int | npt.NDArray[Any]] = {"axis": axis, "n", n}
6060
if prepend is not None:
6161
if prepend.device != x.device:
6262
raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.")

0 commit comments

Comments
 (0)