@@ -1783,48 +1783,42 @@ def test_trunc(x):
1783
1783
1784
1784
def _check_binary_with_scalars (func_data , x1x2 ):
1785
1785
x1 , x2 = x1x2
1786
- func , name , refimpl , kwds , expected_dtype = func_data
1786
+ func_name , refimpl , kwds , expected_dtype = func_data
1787
+ func = getattr (xp , func_name )
1787
1788
out = func (x1 , x2 )
1788
1789
in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1789
1790
_assert_correctness_binary (
1790
- name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1791
+ func_name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1791
1792
)
1792
1793
1793
1794
1794
1795
def _filter_zero (x ):
1795
1796
return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
1796
1797
1797
- # workarounds for xp.copysign etc only available in 2023.12
1798
- # Without it, test suite fails to import with ARRAY_API_VERSION=2022.12
1799
- _xp_copysign = getattr (xp , "copysign" , None )
1800
- _xp_hypot = getattr (xp , "hypot" , None )
1801
- _xp_maximum = getattr (xp , "maximum" , None )
1802
- _xp_minimum = getattr (xp , "minimum" , None )
1803
-
1804
1798
1805
1799
@pytest .mark .min_version ("2024.12" )
1806
1800
@pytest .mark .parametrize ('func_data' ,
1807
- # xp_func, name , refimpl, kwargs, expected_dtype
1801
+ # func_name , refimpl, kwargs, expected_dtype
1808
1802
[
1809
- (xp . add , "add" , operator .add , {}, None ),
1810
- (xp . atan2 , "atan2" , math .atan2 , {}, None ),
1811
- (_xp_copysign , "copysign" , math .copysign , {}, None ),
1812
- (xp . divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1813
- (_xp_hypot , "hypot" , math .hypot , {}, None ),
1814
- (xp . logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1815
- (_xp_maximum , "maximum" , max , {'strict_check' : True }, None ),
1816
- (_xp_minimum , "minimum" , min , {'strict_check' : True }, None ),
1817
- (xp . multiply , "mul " , operator .mul , {}, None ),
1818
- (xp . subtract , "sub " , operator .sub , {}, None ),
1819
-
1820
- (xp . equal , "equal" , operator .eq , {}, xp .bool ),
1821
- (xp . not_equal , "neq " , operator .ne , {}, xp .bool ),
1822
- (xp . less , "less" , operator .lt , {}, xp .bool ),
1823
- (xp . less_equal , "les_equal " , operator .le , {}, xp .bool ),
1824
- (xp . greater , "greater" , operator .gt , {}, xp .bool ),
1825
- (xp . greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1803
+ ("add" , operator .add , {}, None ),
1804
+ ("atan2" , math .atan2 , {}, None ),
1805
+ ("copysign" , math .copysign , {}, None ),
1806
+ ("divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1807
+ ("hypot" , math .hypot , {}, None ),
1808
+ ("logaddexp" , logaddexp_refimpl , {}, None ),
1809
+ ("maximum" , max , {'strict_check' : True }, None ),
1810
+ ("minimum" , min , {'strict_check' : True }, None ),
1811
+ ("multiply " , operator .mul , {}, None ),
1812
+ ("subtract " , operator .sub , {}, None ),
1813
+
1814
+ ("equal" , operator .eq , {}, xp .bool ),
1815
+ ("not_equal " , operator .ne , {}, xp .bool ),
1816
+ ("less" , operator .lt , {}, xp .bool ),
1817
+ ("less_equal " , operator .le , {}, xp .bool ),
1818
+ ("greater" , operator .gt , {}, xp .bool ),
1819
+ ("greater_equal" , operator .ge , {}, xp .bool ),
1826
1820
],
1827
- ids = lambda func_data : func_data [1 ] # use names for test IDs
1821
+ ids = lambda func_data : func_data [0 ] # use names for test IDs
1828
1822
)
1829
1823
@given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1830
1824
def test_binary_with_scalars_real (func_data , x1x2 ):
@@ -1833,13 +1827,13 @@ def test_binary_with_scalars_real(func_data, x1x2):
1833
1827
1834
1828
@pytest .mark .min_version ("2024.12" )
1835
1829
@pytest .mark .parametrize ('func_data' ,
1836
- # xp_func, name , refimpl, kwargs, expected_dtype
1830
+ # func_name , refimpl, kwargs, expected_dtype
1837
1831
[
1838
- (xp . logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1839
- (xp . logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1840
- (xp . logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1832
+ ("logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1833
+ ("logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1834
+ ("logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1841
1835
],
1842
- ids = lambda func_data : func_data [1 ] # use names for test IDs
1836
+ ids = lambda func_data : func_data [0 ] # use names for test IDs
1843
1837
)
1844
1838
@given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1845
1839
def test_binary_with_scalars_bool (func_data , x1x2 ):
@@ -1848,36 +1842,34 @@ def test_binary_with_scalars_bool(func_data, x1x2):
1848
1842
1849
1843
@pytest .mark .min_version ("2024.12" )
1850
1844
@pytest .mark .parametrize ('func_data' ,
1851
- # xp_func, name , refimpl, kwargs, expected_dtype
1845
+ # func_name , refimpl, kwargs, expected_dtype
1852
1846
[
1853
-
1854
- (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
1855
- (xp .remainder , "remainder" , operator .mod , {}, None ),
1847
+ ("floor_divide" , operator .floordiv , {}, None ),
1848
+ ("remainder" , operator .mod , {}, None ),
1856
1849
],
1857
- ids = lambda func_data : func_data [1 ] # use names for test IDs
1850
+ ids = lambda func_data : func_data [0 ] # use names for test IDs
1858
1851
)
1859
1852
@given (x1x2 = hh .array_and_py_scalar ([xp .int64 ]))
1860
1853
def test_binary_with_scalars_int (func_data , x1x2 ):
1861
-
1862
1854
assume (_filter_zero (x1x2 [1 ]))
1863
1855
assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
1864
1856
_check_binary_with_scalars (func_data , x1x2 )
1865
1857
1866
1858
1867
1859
@pytest .mark .min_version ("2024.12" )
1868
1860
@pytest .mark .parametrize ('func_data' ,
1869
- # xp_func, name , refimpl, kwargs, expected_dtype
1861
+ # func_name , refimpl, kwargs, expected_dtype
1870
1862
[
1871
- (xp . bitwise_and , "bitwise_and" , operator .and_ , {}, None ),
1872
- (xp . bitwise_or , "bitwise_or" , operator .or_ , {}, None ),
1873
- (xp . bitwise_xor , "bitwise_xor" , operator .xor , {}, None ),
1863
+ ("bitwise_and" , operator .and_ , {}, None ),
1864
+ ("bitwise_or" , operator .or_ , {}, None ),
1865
+ ("bitwise_xor" , operator .xor , {}, None ),
1874
1866
],
1875
- ids = lambda func_data : func_data [1 ] # use names for test IDs
1867
+ ids = lambda func_data : func_data [0 ] # use names for test IDs
1876
1868
)
1877
1869
@given (x1x2 = hh .array_and_py_scalar ([xp .int32 ]))
1878
1870
def test_binary_with_scalars_bitwise (func_data , x1x2 ):
1879
- xp_func , name , refimpl , kwargs , expected = func_data
1871
+ func_name , refimpl , kwargs , expected = func_data
1880
1872
# repack the refimpl
1881
1873
refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1882
- _check_binary_with_scalars ((xp_func , name , refimpl_ , kwargs ,expected ), x1x2 )
1874
+ _check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
1883
1875
0 commit comments