@@ -366,11 +366,11 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
366
366
skip_if_dtype_not_supported (op1_dtype , q )
367
367
skip_if_dtype_not_supported (op2_dtype , q )
368
368
369
- if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
370
- sz = 127
371
- ar1 = dpt .ones (sz , dtype = op1_dtype )
372
- ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
369
+ sz = 127
370
+ ar1 = dpt .ones (sz , dtype = op1_dtype )
371
+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
373
372
373
+ if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
374
374
ar1 += ar2
375
375
assert (
376
376
dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
@@ -385,7 +385,8 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
385
385
).all ()
386
386
387
387
else :
388
- assert pytest .raises (TypeError )
388
+ with pytest .raises (TypeError ):
389
+ ar1 += ar2
389
390
390
391
391
392
def test_add_inplace_broadcasting ():
@@ -396,3 +397,40 @@ def test_add_inplace_broadcasting():
396
397
397
398
m += v
398
399
assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
400
+
401
+
402
+ def test_add_inplace_errors ():
403
+ get_queue_or_skip ()
404
+ try :
405
+ gpu_queue = dpctl .SyclQueue ("gpu" )
406
+ except dpctl .SyclQueueCreationError :
407
+ pytest .skip ("SyclQueue('gpu') failed, skipping" )
408
+ try :
409
+ cpu_queue = dpctl .SyclQueue ("cpu" )
410
+ except dpctl .SyclQueueCreationError :
411
+ pytest .skip ("SyclQueue('cpu') failed, skipping" )
412
+
413
+ ar1 = dpt .ones (2 , dtype = "float32" , sycl_queue = gpu_queue )
414
+ ar2 = dpt .ones_like (ar1 , sycl_queue = cpu_queue )
415
+ with pytest .raises (ExecutionPlacementError ):
416
+ ar1 += ar2
417
+
418
+ ar1 = dpt .ones (2 , dtype = "float32" )
419
+ ar2 = dpt .ones (3 , dtype = "float32" )
420
+ with pytest .raises (ValueError ):
421
+ ar1 += ar2
422
+
423
+ ar1 = np .ones (2 , dtype = "float32" )
424
+ ar2 = dpt .ones (2 , dtype = "float32" )
425
+ with pytest .raises (TypeError ):
426
+ ar1 += ar2
427
+
428
+ ar1 = dpt .ones (2 , dtype = "float32" )
429
+ ar2 = dict ()
430
+ with pytest .raises (ValueError ):
431
+ ar1 += ar2
432
+
433
+ ar1 = dpt .ones ((2 , 1 ), dtype = "float32" )
434
+ ar2 = dpt .ones ((1 , 2 ), dtype = "float32" )
435
+ with pytest .raises (ValueError ):
436
+ ar1 += ar2
0 commit comments