Skip to content

Commit 8d33de5

Browse files
committed
Use a combination of numpy and NDArray arrays in tests
1 parent 0ce9c30 commit 8d33de5

File tree

1 file changed

+25
-36
lines changed

1 file changed

+25
-36
lines changed

Diff for: tests/ndarray/test_jit.py

+25-36
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,37 @@
1414

1515
###### General expressions
1616

17-
def expr_nojit(a, b, c):
18-
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
19-
20-
@blosc2.jit
21-
def expr_jit(a, b, c):
22-
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
23-
2417
# Define the parameters
2518
test_params = [
2619
((10, 100), (10, 100,), "float32"),
2720
((10, 100), (100,), "float64"), # using broadcasting
2821
]
2922

30-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
31-
def test_expr(shape, cshape, dtype):
23+
@pytest.fixture(params=test_params)
24+
def sample_data(request):
25+
shape, cshape, dtype = request.param
26+
# The jit decorator can work with any numpy or NDArray params in functions
3227
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
33-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
28+
b = np.linspace(1, 2, shape[0] * shape[1], dtype=dtype).reshape(shape)
3429
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
30+
return a, b, c, shape, cshape, dtype
3531

32+
def expr_nojit(a, b, c):
33+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
34+
35+
@blosc2.jit
36+
def expr_jit(a, b, c):
37+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
38+
39+
def test_expr(sample_data):
40+
a, b, c, shape, cshape, dtype = sample_data
3641
d_jit = expr_jit(a, b, c)
3742
d_nojit = expr_nojit(a, b, c)
38-
3943
np.testing.assert_equal(d_jit[...], d_nojit[...])
4044

4145

42-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
43-
def test_expr_out(shape, cshape, dtype):
44-
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
45-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
46-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
46+
def test_expr_out(sample_data):
47+
a, b, c, shape, cshape, dtype = sample_data
4748
d_nojit = expr_nojit(a, b, c)
4849

4950
# Testing jit decorator with an out param
@@ -57,11 +58,8 @@ def expr_jit_out(a, b, c):
5758
np.testing.assert_equal(d_jit[...], d_nojit[...])
5859
np.testing.assert_equal(out[...], d_nojit[...])
5960

60-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
61-
def test_expr_kwargs(shape, cshape, dtype):
62-
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
63-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
64-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
61+
def test_expr_kwargs(sample_data):
62+
a, b, c, shape, cshape, dtype = sample_data
6563
d_nojit = expr_nojit(a, b, c)
6664

6765
# Testing jit decorator with kwargs
@@ -87,22 +85,16 @@ def reduc_nojit(a, b, c):
8785
def reduc_jit(a, b, c):
8886
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
8987

90-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
91-
def test_reduc(shape, cshape, dtype):
92-
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
93-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
94-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
88+
def test_reduc(sample_data):
89+
a, b, c, shape, cshape, dtype = sample_data
9590

9691
d_jit = reduc_jit(a, b, c)
9792
d_nojit = reduc_nojit(a, b, c)
9893

9994
np.testing.assert_equal(d_jit[...], d_nojit[...])
10095

101-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
102-
def test_reduc_out(shape, cshape, dtype):
103-
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
104-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
105-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
96+
def test_reduc_out(sample_data):
97+
a, b, c, shape, cshape, dtype = sample_data
10698
d_nojit = reduc_nojit(a, b, c)
10799

108100
# Testing jit decorator with an out param via the reduction function
@@ -117,11 +109,8 @@ def reduc_jit_out(a, b, c):
117109
np.testing.assert_equal(d_jit[...], d_nojit[...])
118110
np.testing.assert_equal(out[...], d_nojit[...])
119111

120-
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
121-
def test_reduc_kwargs(shape, cshape, dtype):
122-
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
123-
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
124-
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
112+
def test_reduc_kwargs(sample_data):
113+
a, b, c, shape, cshape, dtype = sample_data
125114
d_nojit = reduc_nojit(a, b, c)
126115

127116
# Testing jit decorator with kwargs via an out param in the reduction function

0 commit comments

Comments
 (0)