10
10
_real_numeric_dtypes ,
11
11
_numeric_dtypes ,
12
12
_result_type ,
13
- _dtype_categories as _dtype_dtype_categories ,
13
+ _dtype_categories ,
14
14
)
15
15
from ._array_object import Array
16
16
from ._flags import requires_api_version
@@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):
46
46
47
47
48
48
def create_binary_func (func_name , dtype_category , np_func ):
49
- def inner (x1 : Array , x2 : Array , / ) -> Array :
49
+ def inner (x1 , x2 , / ) -> Array :
50
50
return _binary_ufunc_proto (x1 , x2 , dtype_category , func_name , np_func )
51
51
return inner
52
52
53
53
54
+ # static type annotation for ArrayOrPythonScalar arguments given a category
55
+ # NB: keep the keys in sync with the _dtype_categories dict
56
+ _annotations = {
57
+ "all" : "bool | int | float | complex | Array" ,
58
+ "real numeric" : "int | float | Array" ,
59
+ "numeric" : "int | float | complex | Array" ,
60
+ "integer" : "int | Array" ,
61
+ "integer or boolean" : "int | bool | Array" ,
62
+ "boolean" : "bool | Array" ,
63
+ "real floating-point" : "float | Array" ,
64
+ "complex floating-point" : "complex | Array" ,
65
+ "floating-point" : "float | complex | Array" ,
66
+ }
67
+
68
+
54
69
# func_name: dtype_category (must match that from _dtypes.py)
55
70
_binary_funcs = {
56
71
"add" : "numeric" ,
@@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97
112
# create and attach functions to the module
98
113
for func_name , dtype_category in _binary_funcs .items ():
99
114
# sanity check
100
- assert dtype_category in _dtype_dtype_categories
115
+ assert dtype_category in _dtype_categories
101
116
102
117
numpy_name = _numpy_renames .get (func_name , func_name )
103
118
np_func = getattr (np , numpy_name )
@@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106
121
func .__name__ = func_name
107
122
108
123
func .__doc__ = _binary_docstring_template % (numpy_name , numpy_name )
124
+ func .__annotations__ ['x1' ] = _annotations [dtype_category ]
125
+ func .__annotations__ ['x2' ] = _annotations [dtype_category ]
109
126
110
127
vars ()[func_name ] = func
111
128
@@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
117
134
nextafter = requires_api_version ('2024.12' )(nextafter ) # noqa: F821
118
135
119
136
120
- def bitwise_left_shift (x1 : Array , x2 : Array , / ) -> Array :
137
+ def bitwise_left_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
121
138
is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
122
139
if is_negative :
123
140
raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
124
141
return _bitwise_left_shift (x1 , x2 ) # noqa: F821
125
- bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
142
+ if _bitwise_left_shift .__doc__ : # noqa: F821
143
+ bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
126
144
127
145
128
- def bitwise_right_shift (x1 : Array , x2 : Array , / ) -> Array :
146
+ def bitwise_right_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
129
147
is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
130
148
if is_negative :
131
149
raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
132
150
return _bitwise_right_shift (x1 , x2 ) # noqa: F821
133
- bitwise_right_shift .__doc__ = _bitwise_right_shift .__doc__ # noqa: F821
151
+ if _bitwise_right_shift .__doc__ : # noqa: F821
152
+ bitwise_right_shift .__doc__ = _bitwise_right_shift .__doc__ # noqa: F821
134
153
135
154
136
155
# clean up to not pollute the namespace
0 commit comments