Skip to content

Commit 17d8018

Browse files
committed
add type annotations to binary functions
1 parent 5ddba5b commit 17d8018

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

array_api_strict/_elementwise_functions.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
_real_numeric_dtypes,
1111
_numeric_dtypes,
1212
_result_type,
13-
_dtype_categories as _dtype_dtype_categories,
13+
_dtype_categories,
1414
)
1515
from ._array_object import Array
1616
from ._flags import requires_api_version
@@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):
4646

4747

4848
def create_binary_func(func_name, dtype_category, np_func):
49-
def inner(x1: Array, x2: Array, /) -> Array:
49+
def inner(x1, x2, /) -> Array:
5050
return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func)
5151
return inner
5252

5353

54+
# static type annotation for ArrayOrPythonScalar arguments given a category
55+
# NB: keep the keys in sync with the _dtype_categories dict
56+
_annotations = {
57+
"all": "bool | int | float | complex | Array",
58+
"real numeric": "int | float | Array",
59+
"numeric": "int | float | complex | Array",
60+
"integer": "int | Array",
61+
"integer or boolean": "int | bool | Array",
62+
"boolean": "bool | Array",
63+
"real floating-point": "float | Array",
64+
"complex floating-point": "complex | Array",
65+
"floating-point": "float | complex | Array",
66+
}
67+
68+
5469
# func_name: dtype_category (must match that from _dtypes.py)
5570
_binary_funcs = {
5671
"add": "numeric",
@@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97112
# create and attach functions to the module
98113
for func_name, dtype_category in _binary_funcs.items():
99114
# sanity check
100-
assert dtype_category in _dtype_dtype_categories
115+
assert dtype_category in _dtype_categories
101116

102117
numpy_name = _numpy_renames.get(func_name, func_name)
103118
np_func = getattr(np, numpy_name)
@@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106121
func.__name__ = func_name
107122

108123
func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name)
124+
func.__annotations__['x1'] = _annotations[dtype_category]
125+
func.__annotations__['x2'] = _annotations[dtype_category]
109126

110127
vars()[func_name] = func
111128

@@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
117134
nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821
118135

119136

120-
def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
137+
def bitwise_left_shift(x1: int | Array, x2: int | Array, /) -> Array:
121138
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
122139
if is_negative:
123140
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
124141
return _bitwise_left_shift(x1, x2) # noqa: F821
125-
bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821
142+
if _bitwise_left_shift.__doc__: # noqa: F821
143+
bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821
126144

127145

128-
def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
146+
def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array:
129147
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
130148
if is_negative:
131149
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
132150
return _bitwise_right_shift(x1, x2) # noqa: F821
133-
bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821
151+
if _bitwise_right_shift.__doc__: # noqa: F821
152+
bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821
134153

135154

136155
# clean up to not pollute the namespace

0 commit comments

Comments
 (0)