@@ -690,6 +690,38 @@ def binary_param_assert_against_refimpl(
690
690
)
691
691
692
692
693
+ def _convert_scalars_helper (x1 , x2 ):
694
+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695
+
696
+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697
+ and all arguments converted to arrays.
698
+
699
+ dtypes are separate to help distinguishing between
700
+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701
+ """
702
+ if dh .is_scalar (x1 ):
703
+ in_dtypes = [x2 .dtype ]
704
+ in_shapes = [x2 .shape ]
705
+ x1a , x2a = xp .asarray (x1 ), x2
706
+ elif dh .is_scalar (x2 ):
707
+ in_dtypes = [x1 .dtype ]
708
+ in_shapes = [x1 .shape ]
709
+ x1a , x2a = x1 , xp .asarray (x2 )
710
+ else :
711
+ in_dtypes = [x1 .dtype , x2 .dtype ]
712
+ in_shapes = [x1 .shape , x2 .shape ]
713
+ x1a , x2a = x1 , x2
714
+
715
+ return in_dtypes , in_shapes , (x1a , x2a )
716
+
717
+
718
+ def _assert_correctness_binary (name , func , in_dtypes , in_shapes , in_arrs , out , ** kwargs ):
719
+ x1a , x2a = in_arrs
720
+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
721
+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
722
+ binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
723
+
724
+
693
725
@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694
726
@given (data = st .data ())
695
727
def test_abs (ctx , data ):
@@ -789,10 +821,14 @@ def test_atan(x):
789
821
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
790
822
def test_atan2 (x1 , x2 ):
791
823
out = xp .atan2 (x1 , x2 )
792
- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
793
- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
794
- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
795
- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
824
+ _assert_correctness_binary (
825
+ "atan" ,
826
+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
827
+ in_dtypes = [x1 .dtype , x2 .dtype ],
828
+ in_shapes = [x1 .shape , x2 .shape ],
829
+ in_arrs = [x1 , x2 ],
830
+ out = out ,
831
+ )
796
832
797
833
798
834
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1258,10 +1294,14 @@ def test_greater_equal(ctx, data):
1258
1294
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1259
1295
def test_hypot (x1 , x2 ):
1260
1296
out = xp .hypot (x1 , x2 )
1261
- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1262
- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1263
- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1264
-
1297
+ _assert_correctness_binary (
1298
+ "hypot" ,
1299
+ math .hypot ,
1300
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1301
+ in_shapes = [x1 .shape , x2 .shape ],
1302
+ in_arrs = [x1 , x2 ],
1303
+ out = out
1304
+ )
1265
1305
1266
1306
1267
1307
@pytest .mark .min_version ("2022.12" )
@@ -1411,21 +1451,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
1411
1451
raise OverflowError
1412
1452
1413
1453
1454
+ @pytest .mark .min_version ("2023.12" )
1414
1455
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1415
1456
def test_logaddexp (x1 , x2 ):
1416
1457
out = xp .logaddexp (x1 , x2 )
1417
- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1418
- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1419
- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1420
-
1421
-
1422
- @given (* hh .two_mutual_arrays ([xp .bool ]))
1423
- def test_logical_and (x1 , x2 ):
1424
- out = xp .logical_and (x1 , x2 )
1425
- ph .assert_dtype ("logical_and" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1426
- ph .assert_result_shape ("logical_and" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1427
- binary_assert_against_refimpl (
1428
- "logical_and" , x1 , x2 , out , operator .and_ , expr_template = "({} and {})={}"
1458
+ _assert_correctness_binary (
1459
+ "logaddexp" ,
1460
+ logaddexp_refimpl ,
1461
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1462
+ in_shapes = [x1 .shape , x2 .shape ],
1463
+ in_arrs = [x1 , x2 ],
1464
+ out = out
1429
1465
)
1430
1466
1431
1467
@@ -1439,42 +1475,64 @@ def test_logical_not(x):
1439
1475
)
1440
1476
1441
1477
1478
+ @given (* hh .two_mutual_arrays ([xp .bool ]))
1479
+ def test_logical_and (x1 , x2 ):
1480
+ out = xp .logical_and (x1 , x2 )
1481
+ _assert_correctness_binary (
1482
+ "logical_and" ,
1483
+ operator .and_ ,
1484
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1485
+ in_shapes = [x1 .shape , x2 .shape ],
1486
+ in_arrs = [x1 , x2 ],
1487
+ out = out ,
1488
+ expr_template = "({} and {})={}"
1489
+ )
1490
+
1491
+
1442
1492
@given (* hh .two_mutual_arrays ([xp .bool ]))
1443
1493
def test_logical_or (x1 , x2 ):
1444
1494
out = xp .logical_or (x1 , x2 )
1445
- ph .assert_dtype ("logical_or" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1446
- ph .assert_result_shape ("logical_or" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1447
- binary_assert_against_refimpl (
1448
- "logical_or" , x1 , x2 , out , operator .or_ , expr_template = "({} or {})={}"
1495
+ _assert_correctness_binary (
1496
+ "logical_or" ,
1497
+ operator .or_ ,
1498
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1499
+ in_shapes = [x1 .shape , x2 .shape ],
1500
+ in_arrs = [x1 , x2 ],
1501
+ out = out ,
1502
+ expr_template = "({} or {})={}"
1449
1503
)
1450
1504
1451
1505
1452
1506
@given (* hh .two_mutual_arrays ([xp .bool ]))
1453
1507
def test_logical_xor (x1 , x2 ):
1454
1508
out = xp .logical_xor (x1 , x2 )
1455
- ph .assert_dtype ("logical_xor" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1456
- ph .assert_result_shape ("logical_xor" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1457
- binary_assert_against_refimpl (
1458
- "logical_xor" , x1 , x2 , out , operator .xor , expr_template = "({} ^ {})={}"
1509
+ _assert_correctness_binary (
1510
+ "logical_xor" ,
1511
+ operator .xor ,
1512
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1513
+ in_shapes = [x1 .shape , x2 .shape ],
1514
+ in_arrs = [x1 , x2 ],
1515
+ out = out ,
1516
+ expr_template = "({} ^ {})={}"
1459
1517
)
1460
1518
1461
1519
1462
1520
@pytest .mark .min_version ("2023.12" )
1463
1521
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1464
1522
def test_maximum (x1 , x2 ):
1465
1523
out = xp .maximum (x1 , x2 )
1466
- ph . assert_dtype ( "maximum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1467
- ph . assert_result_shape ( "maximum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1468
- binary_assert_against_refimpl ( "maximum" , x1 , x2 , out , max , strict_check = True )
1524
+ _assert_correctness_binary (
1525
+ "maximum" , max , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1526
+ )
1469
1527
1470
1528
1471
1529
@pytest .mark .min_version ("2023.12" )
1472
1530
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1473
1531
def test_minimum (x1 , x2 ):
1474
1532
out = xp .minimum (x1 , x2 )
1475
- ph . assert_dtype ( "minimum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1476
- ph . assert_result_shape ( "minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477
- binary_assert_against_refimpl ( "minimum" , x1 , x2 , out , min , strict_check = True )
1533
+ _assert_correctness_binary (
1534
+ "minimum" , min , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1535
+ )
1478
1536
1479
1537
1480
1538
@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
@@ -1719,3 +1777,45 @@ def test_trunc(x):
1719
1777
ph .assert_dtype ("trunc" , in_dtype = x .dtype , out_dtype = out .dtype )
1720
1778
ph .assert_shape ("trunc" , out_shape = out .shape , expected = x .shape )
1721
1779
unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
1780
+
1781
+
1782
+ def _check_binary_with_scalars (func_data , x1x2 ):
1783
+ x1 , x2 = x1x2
1784
+ func , name , refimpl , kwds = func_data
1785
+ out = func (x1 , x2 )
1786
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1787
+ _assert_correctness_binary (
1788
+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , ** kwds
1789
+ )
1790
+
1791
+
1792
+ @pytest .mark .min_version ("2024.12" )
1793
+ @pytest .mark .parametrize ('func_data' ,
1794
+ # xp_func, name, refimpl, kwargs
1795
+ [
1796
+ (xp .atan2 , "atan2" , math .atan2 , {}),
1797
+ (xp .hypot , "hypot" , math .hypot , {}),
1798
+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}),
1799
+ (xp .maximum , "maximum" , max , {'strict_check' : True }),
1800
+ (xp .minimum , "minimum" , min , {'strict_check' : True }),
1801
+ ],
1802
+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1803
+ )
1804
+ @given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1805
+ def test_binary_with_scalars_real (func_data , x1x2 ):
1806
+ _check_binary_with_scalars (func_data , x1x2 )
1807
+
1808
+
1809
+ @pytest .mark .min_version ("2024.12" )
1810
+ @pytest .mark .parametrize ('func_data' ,
1811
+ # xp_func, name, refimpl, kwargs
1812
+ [
1813
+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }),
1814
+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }),
1815
+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }),
1816
+ ],
1817
+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1818
+ )
1819
+ @given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1820
+ def test_binary_with_scalars_bool (func_data , x1x2 ):
1821
+ _check_binary_with_scalars (func_data , x1x2 )
0 commit comments