Skip to content

Commit b2af137

Browse files
authored
TYP: Type annotations overhaul, part 2 (#291)
1 parent 62507f4 commit b2af137

File tree

6 files changed

+26
-9
lines changed

6 files changed

+26
-9
lines changed

array_api_compat/common/_aliases.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def eye(
7373

7474
def full(
7575
shape: Union[int, Tuple[int, ...]],
76-
fill_value: complex,
76+
fill_value: bool | int | float | complex,
7777
xp: Namespace,
7878
*,
7979
dtype: Optional[DType] = None,
@@ -86,7 +86,7 @@ def full(
8686
def full_like(
8787
x: Array,
8888
/,
89-
fill_value: complex,
89+
fill_value: bool | int | float | complex,
9090
*,
9191
xp: Namespace,
9292
dtype: Optional[DType] = None,

array_api_compat/cupy/_aliases.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@
6868
# asarray also adds the copy keyword, which is not present in numpy 1.0.
6969
def asarray(
7070
obj: (
71-
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
71+
Array
72+
| bool | int | float | complex
73+
| NestedSequence[bool | int | float | complex]
74+
| SupportsBufferProtocol
7275
),
7376
/,
7477
*,

array_api_compat/dask/array/_aliases.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def arange(
136136
# asarray also adds the copy keyword, which is not present in numpy 1.0.
137137
def asarray(
138138
obj: (
139-
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
139+
Array
140+
| bool | int | float | complex
141+
| NestedSequence[bool | int | float | complex]
142+
| SupportsBufferProtocol
140143
),
141144
/,
142145
*,

array_api_compat/numpy/_aliases.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def _supports_buffer_protocol(obj):
7777
# rather than trying to combine everything into one function in common/
7878
def asarray(
7979
obj: (
80-
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
80+
Array
81+
| bool | int | float | complex
82+
| NestedSequence[bool | int | float | complex]
83+
| SupportsBufferProtocol
8184
),
8285
/,
8386
*,

array_api_compat/torch/_aliases.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
116116
_py_scalars = (bool, int, float, complex)
117117

118118

119-
def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
119+
def result_type(
120+
*arrays_and_dtypes: Array | DType | bool | int | float | complex
121+
) -> DType:
120122
num = len(arrays_and_dtypes)
121123

122124
if num == 0:
@@ -550,10 +552,16 @@ def count_nonzero(
550552
return result
551553

552554

553-
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
555+
def where(
556+
condition: Array,
557+
x1: Array | bool | int | float | complex,
558+
x2: Array | bool | int | float | complex,
559+
/,
560+
) -> Array:
554561
x1, x2 = _fix_promotion(x1, x2)
555562
return torch.where(condition, x1, x2)
556563

564+
557565
# torch.reshape doesn't have the copy keyword
558566
def reshape(x: Array,
559567
/,
@@ -622,7 +630,7 @@ def linspace(start: Union[int, float],
622630
# torch.full does not accept an int size
623631
# https://github.com/pytorch/pytorch/issues/70906
624632
def full(shape: Union[int, Tuple[int, ...]],
625-
fill_value: complex,
633+
fill_value: bool | int | float | complex,
626634
*,
627635
dtype: Optional[DType] = None,
628636
device: Optional[Device] = None,

array_api_compat/torch/linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def vector_norm(
8585
axis: Optional[Union[int, Tuple[int, ...]]] = None,
8686
keepdims: bool = False,
8787
# float stands for inf | -inf, which are not valid for Literal
88-
ord: Union[int, float, float] = 2,
88+
ord: Union[int, float] = 2,
8989
**kwargs,
9090
) -> Array:
9191
# torch.vector_norm incorrectly treats axis=() the same as axis=None

0 commit comments

Comments
 (0)