|
6 | 6 |
|
7 | 7 | # https://github.com/pylint-dev/pylint/issues/10112
|
8 | 8 | from collections.abc import Callable # pylint: disable=import-error
|
9 |
| -from typing import ClassVar |
| 9 | +from typing import ClassVar, Literal |
10 | 10 |
|
11 | 11 | from ._lib import _utils
|
12 | 12 | from ._lib._compat import (
|
@@ -659,11 +659,11 @@ class at: # pylint: disable=invalid-name
|
659 | 659 | idx: Index
|
660 | 660 | __slots__: ClassVar[tuple[str, str]] = ("idx", "x")
|
661 | 661 |
|
662 |
| - def __init__(self, x: Array, idx: Index = _undef, /): |
| 662 | + def __init__(self, x: Array, idx: Index = _undef, /) -> None: |
663 | 663 | self.x = x
|
664 | 664 | self.idx = idx
|
665 | 665 |
|
666 |
| - def __getitem__(self, idx: Index) -> at: |
| 666 | + def __getitem__(self, idx: Index, /) -> at: |
667 | 667 | """Allow for the alternate syntax ``at(x)[start:stop:step]``,
|
668 | 668 | which looks prettier than ``at(x, slice(start, stop, step))``
|
669 | 669 | and feels more intuitive coming from the JAX documentation.
|
@@ -704,19 +704,16 @@ def _common(
|
704 | 704 |
|
705 | 705 | x = self.x
|
706 | 706 |
|
707 |
| - if copy is True: |
| 707 | + if copy is None: |
| 708 | + writeable = is_writeable_array(x) |
| 709 | + copy = _is_update and not writeable |
| 710 | + elif copy: |
708 | 711 | writeable = None
|
709 |
| - elif copy is False: |
| 712 | + else: |
710 | 713 | writeable = is_writeable_array(x)
|
711 | 714 | if not writeable:
|
712 | 715 | msg = "Cannot modify parameter in place"
|
713 | 716 | raise ValueError(msg)
|
714 |
| - elif copy is None: # type: ignore[redundant-expr] |
715 |
| - writeable = is_writeable_array(x) |
716 |
| - copy = _is_update and not writeable |
717 |
| - else: |
718 |
| - msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable] |
719 |
| - raise ValueError(msg) |
720 | 717 |
|
721 | 718 | if copy:
|
722 | 719 | try:
|
@@ -782,7 +779,9 @@ def set(self, y: Array, /, **kwargs: Untyped) -> Array:
|
782 | 779 |
|
783 | 780 | def _iop(
|
784 | 781 | self,
|
785 |
| - at_op: str, |
| 782 | + at_op: Literal[ |
| 783 | + "set", "add", "subtract", "multiply", "divide", "power", "min", "max" |
| 784 | + ], |
786 | 785 | elwise_op: Callable[[Array, Array], Array],
|
787 | 786 | y: Array,
|
788 | 787 | /,
|
|
0 commit comments