Skip to content

Commit c414f5f

Browse files
AzeezIshAzeezIsh
and
AzeezIsh
authored
Trigonometry Testing (#35)
* Added float types method and boolean type. * Incorperated Utility Functions, added checkstyle Checked for all dtype compatibility issues where needed. * Ensuring checkstyle was applied * Adhered to black checkstyle. --------- Co-authored-by: AzeezIsh <[email protected]>
1 parent 5850b9a commit c414f5f

File tree

2 files changed

+164
-3
lines changed

2 files changed

+164
-3
lines changed

tests/test_trig.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtype
6+
import arrayfire_wrapper.lib as wrapper
7+
8+
from . import utility_functions as util
9+
10+
11+
@pytest.mark.parametrize(
12+
"shape",
13+
[
14+
(),
15+
(random.randint(1, 10),),
16+
(random.randint(1, 10), random.randint(1, 10)),
17+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
18+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
19+
],
20+
)
21+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
22+
def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
23+
"""Test inverse sine operation across all supported data types."""
24+
util.check_type_supported(dtype_name)
25+
values = wrapper.randu(shape, dtype_name)
26+
result = wrapper.asin(values)
27+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
28+
29+
30+
@pytest.mark.parametrize(
31+
"shape",
32+
[
33+
(),
34+
(random.randint(1, 10),),
35+
(random.randint(1, 10), random.randint(1, 10)),
36+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
37+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
38+
],
39+
)
40+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
41+
def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
42+
"""Test inverse cosine operation across all supported data types."""
43+
util.check_type_supported(dtype_name)
44+
values = wrapper.randu(shape, dtype_name)
45+
result = wrapper.acos(values)
46+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
47+
48+
49+
@pytest.mark.parametrize(
50+
"shape",
51+
[
52+
(),
53+
(random.randint(1, 10),),
54+
(random.randint(1, 10), random.randint(1, 10)),
55+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
56+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
57+
],
58+
)
59+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
60+
def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
61+
"""Test inverse tan operation across all supported data types."""
62+
util.check_type_supported(dtype_name)
63+
values = wrapper.randu(shape, dtype_name)
64+
result = wrapper.atan(values)
65+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
66+
67+
68+
@pytest.mark.parametrize(
69+
"shape",
70+
[
71+
(),
72+
(random.randint(1, 10),),
73+
(random.randint(1, 10), random.randint(1, 10)),
74+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
75+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
76+
],
77+
)
78+
@pytest.mark.parametrize("dtype_name", util.get_float_types())
79+
def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
80+
"""Test inverse tan operation across all supported data types."""
81+
util.check_type_supported(dtype_name)
82+
if dtype_name == dtype.f16:
83+
pytest.skip()
84+
lhs = wrapper.randu(shape, dtype_name)
85+
rhs = wrapper.randu(shape, dtype_name)
86+
result = wrapper.atan2(lhs, rhs)
87+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
88+
89+
90+
@pytest.mark.parametrize(
91+
"invdtypes",
92+
[
93+
dtype.int16,
94+
dtype.bool,
95+
],
96+
)
97+
def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
98+
"""Test inverse tan operation for unsupported data types."""
99+
with pytest.raises(RuntimeError):
100+
wrapper.atan2(wrapper.randu((10, 10), invdtypes), wrapper.randu((10, 10), invdtypes))
101+
102+
103+
@pytest.mark.parametrize(
104+
"shape",
105+
[
106+
(),
107+
(random.randint(1, 10),),
108+
(random.randint(1, 10), random.randint(1, 10)),
109+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
110+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
111+
],
112+
)
113+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
114+
def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
115+
"""Test cosine operation across all supported data types."""
116+
util.check_type_supported(dtype_name)
117+
values = wrapper.randu(shape, dtype_name)
118+
result = wrapper.cos(values)
119+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
120+
121+
122+
@pytest.mark.parametrize(
123+
"shape",
124+
[
125+
(),
126+
(random.randint(1, 10),),
127+
(random.randint(1, 10), random.randint(1, 10)),
128+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
129+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
130+
],
131+
)
132+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
133+
def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
134+
"""Test sin operation across all supported data types."""
135+
util.check_type_supported(dtype_name)
136+
values = wrapper.randu(shape, dtype_name)
137+
result = wrapper.sin(values)
138+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
139+
140+
141+
@pytest.mark.parametrize(
142+
"shape",
143+
[
144+
(),
145+
(random.randint(1, 10),),
146+
(random.randint(1, 10), random.randint(1, 10)),
147+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
148+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
149+
],
150+
)
151+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
152+
def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
153+
"""Test tan operation across all supported data types."""
154+
util.check_type_supported(dtype_name)
155+
values = wrapper.randu(shape, dtype_name)
156+
result = wrapper.tan(values)
157+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa

tests/utility_functions.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import pytest
22

33
import arrayfire_wrapper.lib as wrapper
4-
from arrayfire_wrapper.dtypes import Dtype, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64
4+
from arrayfire_wrapper.dtypes import Dtype, b8, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64
55

66

77
def check_type_supported(dtype: Dtype) -> None:
88
"""Checks to see if the specified type is supported by the current system"""
99
if dtype in [f64, c64] and not wrapper.get_dbl_support():
1010
pytest.skip("Device does not support double types")
11-
1211
if dtype == f16 and not wrapper.get_half_support():
1312
pytest.skip("Device does not support half types.")
1413

@@ -25,4 +24,9 @@ def get_real_types() -> list:
2524

2625
def get_all_types() -> list:
2726
"""Returns all types"""
28-
return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
27+
return [b8, s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
28+
29+
30+
def get_float_types() -> list:
31+
"""Returns all types"""
32+
return [f16, f32, f64]

0 commit comments

Comments
 (0)