@@ -715,9 +715,11 @@ def _convert_scalars_helper(x1, x2):
715
715
return in_dtypes , in_shapes , (x1a , x2a )
716
716
717
717
718
- def _assert_correctness_binary (name , func , in_dtypes , in_shapes , in_arrs , out , ** kwargs ):
718
+ def _assert_correctness_binary (
719
+ name , func , in_dtypes , in_shapes , in_arrs , out , expected_dtype = None , ** kwargs
720
+ ):
719
721
x1a , x2a = in_arrs
720
- ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
722
+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype , expected = expected_dtype )
721
723
ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
722
724
binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
723
725
@@ -1781,23 +1783,35 @@ def test_trunc(x):
1781
1783
1782
1784
def _check_binary_with_scalars (func_data , x1x2 ):
1783
1785
x1 , x2 = x1x2
1784
- func , name , refimpl , kwds = func_data
1786
+ func , name , refimpl , kwds , expected_dtype = func_data
1785
1787
out = func (x1 , x2 )
1786
1788
in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1787
1789
_assert_correctness_binary (
1788
- name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , ** kwds
1790
+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1789
1791
)
1790
1792
1791
1793
1792
1794
@pytest .mark .min_version ("2024.12" )
1793
1795
@pytest .mark .parametrize ('func_data' ,
1794
- # xp_func, name, refimpl, kwargs
1796
+ # xp_func, name, refimpl, kwargs, expected_dtype
1795
1797
[
1796
- (xp .atan2 , "atan2" , math .atan2 , {}),
1797
- (xp .hypot , "hypot" , math .hypot , {}),
1798
- (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}),
1799
- (xp .maximum , "maximum" , max , {'strict_check' : True }),
1800
- (xp .minimum , "minimum" , min , {'strict_check' : True }),
1798
+ (xp .add , "add" , operator .add , {}, None ),
1799
+ (xp .atan2 , "atan2" , math .atan2 , {}, None ),
1800
+ (xp .copysign , "copysign" , math .copysign , {}, None ),
1801
+ (xp .divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1802
+ (xp .hypot , "hypot" , math .hypot , {}, None ),
1803
+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1804
+ (xp .maximum , "maximum" , max , {'strict_check' : True }, None ),
1805
+ (xp .minimum , "minimum" , min , {'strict_check' : True }, None ),
1806
+ (xp .multiply , "mul" , operator .mul , {}, None ),
1807
+ (xp .subtract , "sub" , operator .sub , {}, None ),
1808
+
1809
+ (xp .equal , "equal" , operator .eq , {}, xp .bool ),
1810
+ (xp .not_equal , "neq" , operator .ne , {}, xp .bool ),
1811
+ (xp .less , "less" , operator .lt , {}, xp .bool ),
1812
+ (xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
1813
+ (xp .greater , "greater" , operator .gt , {}, xp .bool ),
1814
+ (xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1801
1815
],
1802
1816
ids = lambda func_data : func_data [1 ] # use names for test IDs
1803
1817
)
@@ -1808,14 +1822,15 @@ def test_binary_with_scalars_real(func_data, x1x2):
1808
1822
1809
1823
@pytest .mark .min_version ("2024.12" )
1810
1824
@pytest .mark .parametrize ('func_data' ,
1811
- # xp_func, name, refimpl, kwargs
1825
+ # xp_func, name, refimpl, kwargs, expected_dtype
1812
1826
[
1813
- (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }),
1814
- (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }),
1815
- (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }),
1827
+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1828
+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1829
+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1816
1830
],
1817
1831
ids = lambda func_data : func_data [1 ] # use names for test IDs
1818
1832
)
1819
1833
@given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1820
1834
def test_binary_with_scalars_bool (func_data , x1x2 ):
1821
1835
_check_binary_with_scalars (func_data , x1x2 )
1836
+
0 commit comments