|
1 | 1 | import random
|
2 | 2 |
|
3 |
| -# import numpy as np |
4 | 3 | import pytest
|
5 | 4 |
|
6 | 5 | import arrayfire_wrapper.dtypes as dtype
|
7 | 6 | import arrayfire_wrapper.lib as wrapper
|
8 |
| - |
9 |
| -# from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string |
| 7 | +from tests.utility_functions import check_type_supported, get_all_types |
10 | 8 |
|
11 | 9 |
|
12 | 10 | @pytest.mark.parametrize(
|
@@ -41,39 +39,10 @@ def test_add_different_shapes() -> None:
|
41 | 39 | wrapper.add(lhs, rhs)
|
42 | 40 |
|
43 | 41 |
|
44 |
| -dtype_map = { |
45 |
| - "int16": dtype.s16, |
46 |
| - "int32": dtype.s32, |
47 |
| - "int64": dtype.s64, |
48 |
| - "uint8": dtype.u8, |
49 |
| - "uint16": dtype.u16, |
50 |
| - "uint32": dtype.u32, |
51 |
| - "uint64": dtype.u64, |
52 |
| - "float16": dtype.f16, |
53 |
| - "float32": dtype.f32, |
54 |
| - # 'float64': dtype.f64, |
55 |
| - # 'complex64': dtype.c64, |
56 |
| - "complex32": dtype.c32, |
57 |
| - "bool": dtype.b8, |
58 |
| - "s16": dtype.s16, |
59 |
| - "s32": dtype.s32, |
60 |
| - "s64": dtype.s64, |
61 |
| - "u8": dtype.u8, |
62 |
| - "u16": dtype.u16, |
63 |
| - "u32": dtype.u32, |
64 |
| - "u64": dtype.u64, |
65 |
| - "f16": dtype.f16, |
66 |
| - "f32": dtype.f32, |
67 |
| - # 'f64': dtype.f64, |
68 |
| - "c32": dtype.c32, |
69 |
| - # 'c64': dtype.c64, |
70 |
| - "b8": dtype.b8, |
71 |
| -} |
72 |
| - |
73 |
| - |
74 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 42 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
75 | 43 | def test_add_supported_dtypes(dtype_name: dtype.Dtype) -> None:
|
76 | 44 | """Test addition operation across all supported data types."""
|
| 45 | + check_type_supported(dtype_name) |
77 | 46 | shape = (5, 5) # Using a common shape for simplicity
|
78 | 47 | lhs = wrapper.randu(shape, dtype_name)
|
79 | 48 | rhs = wrapper.randu(shape, dtype_name)
|
@@ -143,9 +112,10 @@ def test_subtract_different_shapes() -> None:
|
143 | 112 | wrapper.sub(lhs, rhs)
|
144 | 113 |
|
145 | 114 |
|
146 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 115 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
147 | 116 | def test_subtract_supported_dtypes(dtype_name: dtype.Dtype) -> None:
|
148 | 117 | """Test subtraction operation across all supported data types."""
|
| 118 | + check_type_supported(dtype_name) |
149 | 119 | shape = (5, 5)
|
150 | 120 | lhs = wrapper.randu(shape, dtype_name)
|
151 | 121 | rhs = wrapper.randu(shape, dtype_name)
|
|
0 commit comments