From 25fbfd011e283c07b77e89751a5de2cd155a2999 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Wed, 13 Mar 2024 10:35:41 -0400 Subject: [PATCH 1/2] utility functions for tests --- tests/utility_functions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/utility_functions.py diff --git a/tests/utility_functions.py b/tests/utility_functions.py new file mode 100644 index 0000000..3e45f78 --- /dev/null +++ b/tests/utility_functions.py @@ -0,0 +1,13 @@ +import pytest + +import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import Dtype, c64, f16, f64 + + +def check_type_supported(dtype: Dtype) -> None: + """Checks to see if the specified type is supported by the current system""" + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + pytest.skip("Device does not support double types") + + if dtype == f16 and not wrapper.get_half_support(): + pytest.skip("Device does not support half types.") From cae9a3674bb2cbf397f3c6772b27195914cae57e Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Thu, 14 Mar 2024 11:26:58 -0400 Subject: [PATCH 2/2] added additional utility functions for easier access to types --- tests/utility_functions.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/utility_functions.py b/tests/utility_functions.py index 3e45f78..c89bab1 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -1,7 +1,7 @@ import pytest import arrayfire_wrapper.lib as wrapper -from arrayfire_wrapper.dtypes import Dtype, c64, f16, f64 +from arrayfire_wrapper.dtypes import Dtype, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64 def check_type_supported(dtype: Dtype) -> None: @@ -11,3 +11,18 @@ def check_type_supported(dtype: Dtype) -> None: if dtype == f16 and not wrapper.get_half_support(): pytest.skip("Device does not support half types.") + + +def get_complex_types() -> list: + """Returns all complex types""" + return [c32, c64] + + +def get_real_types() -> list: + """Returns all real types""" + return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64] + + +def get_all_types() -> list: + """Returns all types""" + return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]