Skip to content

Commit 97933ed

Browse files
author
AzeezIsh
committed
Ensured changes were applied.
1 parent 3b4938c commit 97933ed

File tree

1 file changed

+5
-35
lines changed

1 file changed

+5
-35
lines changed

tests/test_arithmetic.py

+5-35
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import random
22

3-
# import numpy as np
43
import pytest
54

65
import arrayfire_wrapper.dtypes as dtype
76
import arrayfire_wrapper.lib as wrapper
8-
9-
# from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string
7+
from tests.utility_functions import check_type_supported, get_all_types
108

119

1210
@pytest.mark.parametrize(
@@ -41,39 +39,10 @@ def test_add_different_shapes() -> None:
4139
wrapper.add(lhs, rhs)
4240

4341

44-
dtype_map = {
45-
"int16": dtype.s16,
46-
"int32": dtype.s32,
47-
"int64": dtype.s64,
48-
"uint8": dtype.u8,
49-
"uint16": dtype.u16,
50-
"uint32": dtype.u32,
51-
"uint64": dtype.u64,
52-
"float16": dtype.f16,
53-
"float32": dtype.f32,
54-
# 'float64': dtype.f64,
55-
# 'complex64': dtype.c64,
56-
"complex32": dtype.c32,
57-
"bool": dtype.b8,
58-
"s16": dtype.s16,
59-
"s32": dtype.s32,
60-
"s64": dtype.s64,
61-
"u8": dtype.u8,
62-
"u16": dtype.u16,
63-
"u32": dtype.u32,
64-
"u64": dtype.u64,
65-
"f16": dtype.f16,
66-
"f32": dtype.f32,
67-
# 'f64': dtype.f64,
68-
"c32": dtype.c32,
69-
# 'c64': dtype.c64,
70-
"b8": dtype.b8,
71-
}
72-
73-
74-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
42+
@pytest.mark.parametrize("dtype_name", get_all_types())
7543
def test_add_supported_dtypes(dtype_name: dtype.Dtype) -> None:
7644
"""Test addition operation across all supported data types."""
45+
check_type_supported(dtype_name)
7746
shape = (5, 5) # Using a common shape for simplicity
7847
lhs = wrapper.randu(shape, dtype_name)
7948
rhs = wrapper.randu(shape, dtype_name)
@@ -143,9 +112,10 @@ def test_subtract_different_shapes() -> None:
143112
wrapper.sub(lhs, rhs)
144113

145114

146-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
115+
@pytest.mark.parametrize("dtype_name", get_all_types())
147116
def test_subtract_supported_dtypes(dtype_name: dtype.Dtype) -> None:
148117
"""Test subtraction operation across all supported data types."""
118+
check_type_supported(dtype_name)
149119
shape = (5, 5)
150120
lhs = wrapper.randu(shape, dtype_name)
151121
rhs = wrapper.randu(shape, dtype_name)

0 commit comments

Comments
 (0)