forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_helpers.py
40 lines (33 loc) · 1.49 KB
/
_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Private helper routines."""
from ._array_object import Array
from ._dtypes import _dtype_categories
from ._flags import get_array_api_strict_flags
# Built-in Python scalar types. `_maybe_normalize_py_scalars` uses this tuple
# in `isinstance` checks to detect an operand passed as a plain Python scalar
# instead of an Array.
_py_scalars = (bool, int, float, complex)
def _maybe_normalize_py_scalars(
    x1: Array | bool | int | float | complex,
    x2: Array | bool | int | float | complex,
    dtype_category: str,
    func_name: str,
) -> tuple[Array, Array]:
    """Validate the operands of a binary function and promote a Python scalar.

    For ``api_version`` "2024.12" and later, at most one of ``x1``/``x2`` may
    be a Python scalar (``bool``, ``int``, ``float`` or ``complex``). That
    scalar is converted with the other operand's ``_promote_scalar`` method.
    The Array operand(s) must have a dtype from
    ``_dtype_categories[dtype_category]``.

    For older API versions both operands are returned unchanged. Any Python
    scalar is then left for the calling function to reject.

    Parameters
    ----------
    x1, x2
        The two operands. Each is an Array or a Python scalar.
    dtype_category
        Key into ``_dtype_categories`` selecting the allowed dtypes.
    func_name
        Name of the calling function, used only in error messages.

    Returns
    -------
    tuple
        ``(x1, x2)``, with a Python scalar operand replaced by the result of
        ``_promote_scalar``. For API versions before "2024.12" the inputs are
        returned as-is.

    Raises
    ------
    TypeError
        If both operands are Python scalars, or if an Array operand's dtype
        is not in the allowed category.
    """
    flags = get_array_api_strict_flags()
    if flags["api_version"] < "2024.12":
        # scalars will fail at the call site
        return x1, x2  # type: ignore[return-value]

    allowed_dtypes = _dtype_categories[dtype_category]

    if isinstance(x1, _py_scalars):
        if isinstance(x2, _py_scalars):
            raise TypeError(
                f"Two scalars not allowed, got {type(x1) = } and {type(x2) = }"
            )
        # x2 must be an array; x1 is promoted using x2's context.
        if x2.dtype not in allowed_dtypes:
            raise TypeError(
                f"Only {dtype_category} dtypes are allowed in {func_name}(...). "
                f"Got {x2.dtype}."
            )
        x1 = x2._promote_scalar(x1)
    elif isinstance(x2, _py_scalars):
        # x1 must be an array; x2 is promoted using x1's context.
        if x1.dtype not in allowed_dtypes:
            raise TypeError(
                f"Only {dtype_category} dtypes are allowed in {func_name}(...). "
                f"Got {x1.dtype}."
            )
        x2 = x1._promote_scalar(x2)
    else:
        # Both operands are arrays: each must be in the allowed category.
        if x1.dtype not in allowed_dtypes or x2.dtype not in allowed_dtypes:
            raise TypeError(
                f"Only {dtype_category} dtypes are allowed in {func_name}(...). "
                f"Got {x1.dtype} and {x2.dtype}."
            )
    return x1, x2