From ea96e9b2dc186eacaec33df67c68d716f17bfb18 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 22 Nov 2024 11:07:40 +0200 Subject: [PATCH 1/3] ENH: allow python scalars in binary elementwise functions Allow func(array, scalar) and func(scalar, array), raise on func(scalar, scalar) if API_VERSION>=2024.12 cross-ref https://github.com/data-apis/array-api/issues/807 To make sure it is all uniform, 1. Generate all binary "ufuncs" in a uniform way, with a decorator 2. Make binary "ufuncs" follow the same logic of the binary operators 3. Reuse the test loop of Array.__binop__ for binary "ufuncs" 4. (minor) in tests, reuse canonical names for dtype categories ("integer or boolean" vs "integer_or_boolean") --- array_api_strict/_array_object.py | 2 + array_api_strict/_elementwise_functions.py | 584 ++++-------------- array_api_strict/_helpers.py | 37 ++ array_api_strict/tests/test_array_object.py | 101 +-- .../tests/test_elementwise_functions.py | 54 +- 5 files changed, 266 insertions(+), 512 deletions(-) create mode 100644 array_api_strict/_helpers.py diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..47153e5 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -230,6 +230,8 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7c64f67..3c4b3d8 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -10,17 +10,133 @@ _real_numeric_dtypes, _numeric_dtypes, _result_type, + _dtype_categories as _dtype_dtype_categories, ) from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo +from ._helpers import _maybe_normalize_py_scalars from typing import Optional, Union import numpy as np +def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): + """Base implementation of a binary function, `func_name`, defined for + dtypes from `dtype_category` + """ + x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name) + + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np_func(x1._array, x2._array), device=x1.device) + + +_binary_docstring_template=\ +""" +Array API compatible wrapper for :py:func:`np.%s `. + +See its docstring for more information. +""" + + +def create_binary_func(func_name, dtype_category, np_func): + def inner(x1: Array, x2: Array, /) -> Array: + return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) + return inner + + +# func_name: dtype_category (must match that from _dtypes.py) +_binary_funcs = { + "add": "numeric", + "atan2": "real floating-point", + "bitwise_and": "integer or boolean", + "bitwise_or": "integer or boolean", + "bitwise_xor": "integer or boolean", + "_bitwise_left_shift": "integer", # leading underscore deliberate + "_bitwise_right_shift": "integer", + # XXX: copysign: real fp or numeric? + "copysign": "real floating-point", + "divide": "floating-point", + "equal": "all", + "greater": "real numeric", + "greater_equal": "real numeric", + "less": "real numeric", + "less_equal": "real numeric", + "not_equal": "all", + "floor_divide": "real numeric", + "hypot": "real floating-point", + "logaddexp": "real floating-point", + "logical_and": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", + "multiply": "numeric", + "nextafter": "real floating-point", + "pow": "numeric", + "remainder": "real numeric", + "subtract": "numeric", +} + + +# map array-api-name : numpy-name +_numpy_renames = { + "atan2": "arctan2", + "_bitwise_left_shift": "left_shift", + "_bitwise_right_shift": "right_shift", + "pow": "power" +} + + +# create and attach functions to the module +for func_name, dtype_category in _binary_funcs.items(): + # sanity check + assert dtype_category in _dtype_dtype_categories + + numpy_name = _numpy_renames.get(func_name, func_name) + np_func = getattr(np, numpy_name) + + func = create_binary_func(func_name, dtype_category, np_func) + func.__name__ = func_name + + func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + + vars()[func_name] = func + + +copysign = requires_api_version('2023.12')(copysign) # noqa: F821 +hypot = requires_api_version('2023.12')(hypot) # noqa: F821 +maximum = requires_api_version('2023.12')(maximum) # noqa: F821 +minimum = requires_api_version('2023.12')(minimum) # noqa: F821 +nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821 + + +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_left_shift(x1, x2) # noqa: F821 +bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 + + +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_right_shift(x1, x2) # noqa: F821 +bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 + + +# clean up to not pollute the namespace +del func, create_binary_func + + def abs(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.abs `. @@ -56,23 +172,6 @@ def acosh(x: Array, /) -> Array: return Array._new(np.arccosh(x._array), device=x.device) -def add(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.add `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in add") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.add(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def asin(x: Array, /) -> Array: """ @@ -109,23 +208,6 @@ def atan(x: Array, /) -> Array: return Array._new(np.arctan(x._array), device=x.device) -# Note: the function name is different here -def atan2(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan2 `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in atan2") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.arctan2(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def atanh(x: Array, /) -> Array: """ @@ -138,47 +220,6 @@ def atanh(x: Array, /) -> Array: return Array._new(np.arctanh(x._array), device=x.device) -def bitwise_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_and `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.left_shift `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_left_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.left_shift(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: """ @@ -191,67 +232,6 @@ def bitwise_invert(x: Array, /) -> Array: return Array._new(np.invert(x._array), device=x.device) -def bitwise_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_or `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.right_shift `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_right_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.right_shift(x1._array, x2._array), device=x1.device) - - -def bitwise_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_xor `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device) - - def ceil(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.ceil `. @@ -372,6 +352,7 @@ def _isscalar(a): out[ib] = b[ib] return Array._new(out, device=device) + def conj(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.conj `. @@ -382,22 +363,6 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x._array), device=x.device) -@requires_api_version('2023.12') -def copysign(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.copysign `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real numeric dtypes are allowed in copysign") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.copysign(x1._array, x2._array), device=x1.device) def cos(x: Array, /) -> Array: """ @@ -421,36 +386,6 @@ def cosh(x: Array, /) -> Array: return Array._new(np.cosh(x._array), device=x.device) -def divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.divide `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.divide(x1._array, x2._array), device=x1.device) - - -def equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.equal(x1._array, x2._array), device=x1.device) - - def exp(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.exp `. @@ -487,69 +422,6 @@ def floor(x: Array, /) -> Array: return Array._new(np.floor(x._array), device=x.device) -def floor_divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor_divide `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor_divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device) - - -def greater(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater(x1._array, x2._array), device=x1.device) - - -def greater_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def hypot(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.hypot `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in hypot") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.hypot(x1._array, x2._array), device=x1.device) - def imag(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.imag `. @@ -594,38 +466,6 @@ def isnan(x: Array, /) -> Array: return Array._new(np.isnan(x._array), device=x.device) -def less(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less(x1._array, x2._array), device=x1.device) - - -def less_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less_equal(x1._array, x2._array), device=x1.device) - - def log(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log `. @@ -670,38 +510,6 @@ def log10(x: Array, /) -> Array: return Array._new(np.log10(x._array), device=x.device) -def logaddexp(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logaddexp `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in logaddexp") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device) - - -def logical_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_and `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_and(x1._array, x2._array), device=x1.device) - - def logical_not(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_not `. @@ -713,87 +521,6 @@ def logical_not(x: Array, /) -> Array: return Array._new(np.logical_not(x._array), device=x.device) -def logical_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_or `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_or(x1._array, x2._array), device=x1.device) - - -def logical_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_xor `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def maximum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.maximum `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in maximum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error - # in that case? - return Array._new(np.maximum(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def minimum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.minimum `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in minimum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.minimum(x1._array, x2._array), device=x1.device) - -def multiply(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.multiply `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in multiply") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.multiply(x1._array, x2._array), device=x1.device) - - def negative(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.negative `. @@ -805,34 +532,6 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) -@requires_api_version('2024.12') -def nextafter(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.nextafter `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in nextafter") - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) - -def not_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.not_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.not_equal(x1._array, x2._array), device=x1.device) - - def positive(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.positive `. @@ -844,23 +543,6 @@ def positive(x: Array, /) -> Array: return Array._new(np.positive(x._array), device=x.device) -# Note: the function name is different here -def pow(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.power `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in pow") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.power(x1._array, x2._array), device=x1.device) - - def real(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.real `. @@ -883,22 +565,6 @@ def reciprocal(x: Array, /) -> Array: raise TypeError("Only floating-point dtypes are allowed in reciprocal") return Array._new(np.reciprocal(x._array), device=x.device) -def remainder(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.remainder `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in remainder") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.remainder(x1._array, x2._array), device=x1.device) - - def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round `. @@ -979,22 +645,6 @@ def sqrt(x: Array, /) -> Array: return Array._new(np.sqrt(x._array), device=x.device) -def subtract(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.subtract `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in subtract") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.subtract(x1._array, x2._array), device=x1.device) - - def tan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tan `. diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py new file mode 100644 index 0000000..2258d29 --- /dev/null +++ b/array_api_strict/_helpers.py @@ -0,0 +1,37 @@ +"""Private helper routines. +""" + +from ._flags import get_array_api_strict_flags +from ._dtypes import _dtype_categories + +_py_scalars = (bool, int, float, complex) + + +def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): + + flags = get_array_api_strict_flags() + if flags["api_version"] < "2024.12": + # scalars will fail at the call site + return x1, x2 + + _allowed_dtypes = _dtype_categories[dtype_category] + + if isinstance(x1, _py_scalars): + if isinstance(x2, _py_scalars): + raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") + # x2 must be an array + if x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") + x1 = x2._promote_scalar(x1) + + elif isinstance(x2, _py_scalars): + # x1 must be an array + if x1.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") + x2 = x1._promote_scalar(x2) + else: + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. " + f"Got {x1.dtype} and {x2.dtype}.") + return x1, x2 + diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 8f185f0..4535d99 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device(): assert y.device == device1 + +BIG_INT = int(1e30) + +def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for real floating-point array dtypes + # - a Python int, float, or complex for complex floating-point array dtypes + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating-point" and a.dtype in _floating_dtypes + or dtypes == "real floating-point" and a.dtype in _real_floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] + )): + if a.dtype in _integer_dtypes and s == BIG_INT: + with assert_raises(OverflowError): + func(s) + return False + + else: + # Only test for no error + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + func(s) + return True + + else: + with assert_raises(TypeError): + func(s) + return False + + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", - "__and__": "integer_or_boolean", + "__and__": "integer or boolean", "__eq__": "all", "__floordiv__": "real numeric", "__ge__": "real numeric", @@ -112,12 +160,12 @@ def test_operators(): "__mod__": "real numeric", "__mul__": "numeric", "__ne__": "all", - "__or__": "integer_or_boolean", + "__or__": "integer or boolean", "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", - "__truediv__": "floating", - "__xor__": "integer_or_boolean", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", } # Recompute each time because of in-place ops def _array_vals(): @@ -128,8 +176,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - - BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: @@ -139,40 +185,7 @@ def _array_vals(): for s in [1, 1.0, 1j, BIG_INT, False]: for _op in ops: for a in _array_vals(): - # Test array op scalar. From the spec, the following combinations - # are supported: - - # - Python bool for a bool array dtype, - # - a Python int within the bounds of the given dtype for integer array dtypes, - # - a Python int or float for real floating-point array dtypes - # - a Python int, float, or complex for complex floating-point array dtypes - - if ((dtypes == "all" - or dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes - or dtypes == "integer" and a.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes - or dtypes == "boolean" and a.dtype in _boolean_dtypes - or dtypes == "floating" and a.dtype in _floating_dtypes - ) - # bool is a subtype of int, which is why we avoid - # isinstance here. - and (a.dtype in _boolean_dtypes and type(s) == bool - or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _real_floating_dtypes and type(s) in [float, int] - or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] - )): - if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s)) - else: - # Only test for no error - with suppress_warnings() as sup: - # ignore warnings from pow(BIG_INT) - sup.filter(RuntimeWarning, - "invalid value encountered in power") - getattr(a, _op)(s) - else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + _check_op_array_scalar(dtypes, a, s, getattr(a, _op), _op) # Test array op array. for _op in ops: @@ -203,10 +216,10 @@ def _array_vals(): or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or dtypes == "integer or boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes - or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + or dtypes == "floating-point" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes ): getattr(x, _op)(y) else: @@ -214,7 +227,7 @@ def _array_vals(): unary_op_dtypes = { "__abs__": "numeric", - "__invert__": "integer_or_boolean", + "__invert__": "integer or boolean", "__neg__": "numeric", "__pos__": "numeric", } @@ -223,7 +236,7 @@ def _array_vals(): if ( dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "integer_or_boolean" + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes ): # Only test for no error diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 4e1b9cc..0b90f0b 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,7 @@ from inspect import signature, getmodule -from numpy.testing import assert_raises +from pytest import raises as assert_raises +from numpy.testing import suppress_warnings import pytest @@ -19,6 +20,8 @@ ) from .._flags import set_array_api_strict_flags +from .test_array_object import _check_op_array_scalar, BIG_INT + import array_api_strict @@ -120,6 +123,7 @@ def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + mod_funcs = [n for n in mod_funcs if not n.startswith("_")] assert set(mod_funcs) == set(elementwise_function_input_types) @@ -202,3 +206,51 @@ def test_bitwise_shift_error(): assert_raises( ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) ) + + + +def test_scalars(): + # mirror test_array_object.py::test_operators() + # + # Also check that binary functions accept (array, scalar) and (scalar, array) + # arguments, and reject (scalar, scalar) arguments. + + # Use the latest version of the standard so that scalars are actually allowed + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") + + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1.0, dtype=d) + + + for func_name, dtypes in elementwise_function_input_types.items(): + func = getattr(_elementwise_functions, func_name) + if nargs(func) != 2: + continue + + for s in [1, 1.0, 1j, BIG_INT, False]: + for a in _array_vals(): + for func1 in [lambda s: func(a, s), lambda s: func(s, a)]: + allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) + + # only check `func(array, scalar) == `func(array, array)` if + # the former is legal under the promotion rules + if allowed: + conv_scalar = asarray(s, dtype=a.dtype) + + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + assert func(s, a) == func(conv_scalar, a) + assert func(a, s) == func(a, conv_scalar) + + with assert_raises(TypeError): + func(s, s) + + From 33055ce2b0d5560a6275f97214bef030ec00b260 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Dec 2024 16:13:59 +0200 Subject: [PATCH 2/3] add type annotations to binary functions --- array_api_strict/_elementwise_functions.py | 33 +++++++++++++++++----- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 3c4b3d8..54691d6 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -10,7 +10,7 @@ _real_numeric_dtypes, _numeric_dtypes, _result_type, - _dtype_categories as _dtype_dtype_categories, + _dtype_categories, ) from ._array_object import Array from ._flags import requires_api_version @@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): def create_binary_func(func_name, dtype_category, np_func): - def inner(x1: Array, x2: Array, /) -> Array: + def inner(x1, x2, /) -> Array: return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) return inner +# static type annotation for ArrayOrPythonScalar arguments given a category +# NB: keep the keys in sync with the _dtype_categories dict +_annotations = { + "all": "bool | int | float | complex | Array", + "real numeric": "int | float | Array", + "numeric": "int | float | complex | Array", + "integer": "int | Array", + "integer or boolean": "int | bool | Array", + "boolean": "bool | Array", + "real floating-point": "float | Array", + "complex floating-point": "complex | Array", + "floating-point": "float | complex | Array", +} + + # func_name: dtype_category (must match that from _dtypes.py) _binary_funcs = { "add": "numeric", @@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array: # create and attach functions to the module for func_name, dtype_category in _binary_funcs.items(): # sanity check - assert dtype_category in _dtype_dtype_categories + assert dtype_category in _dtype_categories numpy_name = _numpy_renames.get(func_name, func_name) np_func = getattr(np, numpy_name) @@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array: func.__name__ = func_name func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + func.__annotations__['x1'] = _annotations[dtype_category] + func.__annotations__['x2'] = _annotations[dtype_category] vars()[func_name] = func @@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array: nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821 -def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: +def bitwise_left_shift(x1: int | Array, x2: int | Array, /) -> Array: is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 if is_negative: raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return _bitwise_left_shift(x1, x2) # noqa: F821 -bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 +if _bitwise_left_shift.__doc__: # noqa: F821 + bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 -def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: +def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 if is_negative: raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return _bitwise_right_shift(x1, x2) # noqa: F821 -bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 +if _bitwise_right_shift.__doc__: # noqa: F821 + bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 # clean up to not pollute the namespace From 035cf2da0cb0c492897d6c2f68ffdf282ec55931 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Jan 2025 12:23:44 +0100 Subject: [PATCH 3/3] MAINT: undo the array_object.py change --- array_api_strict/_array_object.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 47153e5..a917441 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -230,8 +230,6 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - else: - raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar):