@@ -448,7 +448,7 @@ def test_iter():
448
448
pytest .raises (TypeError , lambda : iter (ones ((3 , 3 ))))
449
449
450
450
@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
451
- def dlpack_2023_12 (api_version ):
451
+ def test_dlpack_2023_12 (api_version ):
452
452
if api_version == '2021.12' :
453
453
with pytest .warns (UserWarning ):
454
454
set_array_api_strict_flags (api_version = api_version )
@@ -459,25 +459,35 @@ def dlpack_2023_12(api_version):
459
459
# Never an error
460
460
a .__dlpack__ ()
461
461
462
-
463
- if np .__version__ < '2.1' :
464
- exception = NotImplementedError if api_version >= '2023.12' else ValueError
465
- pytest .raises (exception , lambda :
466
- a .__dlpack__ (dl_device = CPU_DEVICE ))
467
- pytest .raises (exception , lambda :
462
+ if api_version < '2023.12' :
463
+ pytest .raises (ValueError , lambda :
464
+ a .__dlpack__ (dl_device = a .__dlpack_device__ ()))
465
+ pytest .raises (ValueError , lambda :
468
466
a .__dlpack__ (dl_device = None ))
469
- pytest .raises (exception , lambda :
467
+ pytest .raises (ValueError , lambda :
470
468
a .__dlpack__ (max_version = (1 , 0 )))
471
- pytest .raises (exception , lambda :
469
+ pytest .raises (ValueError , lambda :
472
470
a .__dlpack__ (max_version = None ))
473
- pytest .raises (exception , lambda :
474
- a .__dlpack__ (copy = False ))
475
- pytest .raises (exception , lambda :
476
- a .__dlpack__ (copy = True ))
477
- pytest .raises (exception , lambda :
478
- a .__dlpack__ (copy = None ))
471
+ pytest .raises (ValueError , lambda :
472
+ a .__dlpack__ (copy = False ))
473
+ pytest .raises (ValueError , lambda :
474
+ a .__dlpack__ (copy = True ))
475
+ pytest .raises (ValueError , lambda :
476
+ a .__dlpack__ (copy = None ))
477
+ elif np .lib .NumpyVersion (np .__version__ ) < '2.1.0' :
478
+ pytest .raises (NotImplementedError , lambda :
479
+ a .__dlpack__ (dl_device = CPU_DEVICE ))
480
+ a .__dlpack__ (dl_device = None )
481
+ pytest .raises (NotImplementedError , lambda :
482
+ a .__dlpack__ (max_version = (1 , 0 )))
483
+ a .__dlpack__ (max_version = None )
484
+ pytest .raises (NotImplementedError , lambda :
485
+ a .__dlpack__ (copy = False ))
486
+ pytest .raises (NotImplementedError , lambda :
487
+ a .__dlpack__ (copy = True ))
488
+ a .__dlpack__ (copy = None )
479
489
else :
480
- a .__dlpack__ (dl_device = CPU_DEVICE )
490
+ a .__dlpack__ (dl_device = a . __dlpack_device__ () )
481
491
a .__dlpack__ (dl_device = None )
482
492
a .__dlpack__ (max_version = (1 , 0 ))
483
493
a .__dlpack__ (max_version = None )
0 commit comments