diff --git a/tests/test_trig.py b/tests/test_trig.py index ae62f7b..c81d26e 100644 --- a/tests/test_trig.py +++ b/tests/test_trig.py @@ -4,8 +4,7 @@ import arrayfire_wrapper.dtypes as dtype import arrayfire_wrapper.lib as wrapper - -from . import utility_functions as util +from tests.utility_functions import check_type_supported, get_all_types, get_float_types @pytest.mark.parametrize( @@ -18,10 +17,10 @@ (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test inverse sine operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.asin(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa @@ -37,10 +36,10 @@ def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test inverse cosine operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.acos(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa @@ -56,10 +55,10 @@ def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test inverse tan operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.atan(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa @@ -75,10 +74,10 @@ def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_float_types()) +@pytest.mark.parametrize("dtype_name", get_float_types()) def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test inverse tan operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) if dtype_name == dtype.f16: pytest.skip() lhs = wrapper.randu(shape, dtype_name) @@ -110,10 +109,10 @@ def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test cosine operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.cos(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa @@ -129,10 +128,10 @@ def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test sin operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.sin(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa @@ -148,10 +147,10 @@ def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), ], ) -@pytest.mark.parametrize("dtype_name", util.get_all_types()) +@pytest.mark.parametrize("dtype_name", get_all_types()) def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: """Test tan operation across all supported data types.""" - util.check_type_supported(dtype_name) + check_type_supported(dtype_name) values = wrapper.randu(shape, dtype_name) result = wrapper.tan(values) assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa