@@ -411,72 +411,47 @@ def _matmul_array_vals():
411
411
x .__imatmul__ (y )
412
412
413
413
414
- @pytest .mark .parametrize (
415
- "op" ,
416
- [
417
- op for op , dtypes in binary_op_dtypes .items ()
418
- if dtypes not in ("real numeric" , "floating-point" )
419
- ],
420
- )
421
- def test_binary_operators_vs_numpy_int (op ):
422
- """np.int64 is not a subclass of int and must be disallowed"""
423
- a = asarray (1 )
424
- i64 = np .int64 (1 )
425
- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
426
- getattr (a , op )(i64 )
427
-
428
-
429
- @pytest .mark .parametrize (
430
- "op" ,
431
- [
432
- op for op , dtypes in binary_op_dtypes .items ()
433
- if dtypes not in ("integer" , "integer or boolean" )
434
- ],
435
- )
436
- def test_binary_operators_vs_numpy_float (op ):
437
- """
438
- np.float64 is a subclass of float and must be allowed.
439
- np.float32 is not and must be rejected.
440
- """
441
- a = asarray (1. )
442
- f64 = np .float64 (1. )
443
- f32 = np .float32 (1. )
444
- func = getattr (a , op )
445
- for op in binary_op_dtypes :
446
- assert isinstance (func (f64 ), Array )
447
- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
448
- func (f32 )
449
-
450
-
451
- @pytest .mark .parametrize (
452
- "op" ,
453
- [
454
- op for op , dtypes in binary_op_dtypes .items ()
455
- if dtypes not in ("integer" , "integer or boolean" , "real numeric" )
456
- ],
457
- )
458
- def test_binary_operators_vs_numpy_complex (op ):
459
- """
460
- np.complex128 is a subclass of complex and must be allowed.
461
- np.complex64 is not and must be rejected.
414
+ @pytest .mark .parametrize ("op,dtypes" , binary_op_dtypes .items ())
415
+ def test_binary_operators_vs_numpy_generics (op , dtypes ):
416
+ """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128
417
+ are disallowed in binary operators.
418
+ np.float64 and np.complex128 are subclasses of float and complex, so they need
419
+ special treatment in order to be rejected.
462
420
"""
463
- a = asarray (1. )
464
- c64 = np .complex64 (1. )
465
- c128 = np .complex128 (1. )
466
- func = getattr (a , op )
467
- for op in binary_op_dtypes :
468
- assert isinstance (func (c128 ), Array )
469
- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
470
- func (c64 )
421
+ match = "Expected Array or Python scalar"
422
+
423
+ if dtypes not in ("numeric" , "integer" , "real numeric" , "floating-point" ):
424
+ a = asarray (True )
425
+ func = getattr (a , op )
426
+ with pytest .raises (TypeError , match = match ):
427
+ func (np .bool_ (True ))
428
+
429
+ if dtypes != "floating-point" :
430
+ a = asarray (1 )
431
+ func = getattr (a , op )
432
+ with pytest .raises (TypeError , match = match ):
433
+ func (np .int64 (1 ))
434
+
435
+ if dtypes not in ("integer" , "integer or boolean" ):
436
+ a = asarray (1. ,)
437
+ func = getattr (a , op )
438
+ with pytest .raises (TypeError , match = match ):
439
+ func (np .float32 (1. ))
440
+ with pytest .raises (TypeError , match = match ):
441
+ func (np .float64 (1. ))
442
+
443
+ if dtypes not in ("integer" , "integer or boolean" , "real numeric" ):
444
+ a = asarray (1. ,)
445
+ func = getattr (a , op )
446
+ with pytest .raises (TypeError , match = match ):
447
+ func (np .complex64 (1. ))
448
+ with pytest .raises (TypeError , match = match ):
449
+ func (np .complex128 (1. ))
471
450
472
451
473
452
@pytest .mark .parametrize ("op,dtypes" , binary_op_dtypes .items ())
474
453
def test_binary_operators_device_mismatch (op , dtypes ):
475
- if dtypes in ("real numeric" , "floating-point" ):
476
- dtype = float64
477
- else :
478
- dtype = int64
479
-
454
+ dtype = float64 if dtypes == "floating-point" else int64
480
455
a = asarray (1 , dtype = dtype , device = CPU_DEVICE )
481
456
b = asarray (1 , dtype = dtype , device = Device ("device1" ))
482
457
with pytest .raises (ValueError , match = "different devices" ):
0 commit comments