Skip to content

Commit 77cc75e

Browse files
committed
Add more tests for jit decorator, and fix some issues
1 parent 485c462 commit 77cc75e

File tree

3 files changed

+149
-9
lines changed

3 files changed

+149
-9
lines changed

src/blosc2/lazyexpr.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def extract_numpy_scalars(expr: str):
587587
return transformed_expr, transformer.replacements
588588

589589

590-
def validate_inputs(inputs: dict, out=None) -> tuple: # noqa: C901
590+
def validate_inputs(inputs: dict, out=None, reduce=False) -> tuple: # noqa: C901
591591
"""Validate the inputs for the expression."""
592592
if len(inputs) == 0:
593593
if out is None:
@@ -624,7 +624,7 @@ def validate_inputs(inputs: dict, out=None) -> tuple: # noqa: C901
624624
fast_path = True
625625
first_input = NDinputs[0]
626626
# Check the out NDArray (if present) first
627-
if isinstance(out, blosc2.NDArray):
627+
if isinstance(out, blosc2.NDArray) and not reduce:
628628
if first_input.shape != out.shape:
629629
raise ValueError("Output shape does not match the first input shape")
630630
if first_input.chunks != out.chunks:
@@ -1590,14 +1590,15 @@ def chunked_eval( # noqa: C901
15901590
if where:
15911591
# Make the where arguments part of the operands
15921592
operands = {**operands, **where}
1593-
_, _, _, fast_path = validate_inputs(operands, out)
1593+
1594+
reduce_args = kwargs.pop("_reduce_args", {})
1595+
_, _, _, fast_path = validate_inputs(operands, out, reduce=reduce_args != {})
15941596

15951597
# Activate last read cache for NDField instances
15961598
for op in operands:
15971599
if isinstance(operands[op], blosc2.NDField):
15981600
operands[op].ndarr.keep_last_read = True
15991601

1600-
reduce_args = kwargs.pop("_reduce_args", {})
16011602
if reduce_args:
16021603
# Eval and reduce the expression in a single step
16031604
return reduce_slices(expression, operands, reduce_args=reduce_args, _slice=item, **kwargs)

src/blosc2/proxy.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from abc import ABC, abstractmethod
99

1010
import numpy as np
11-
from traitlets import Callable
1211

1312
import blosc2
1413

@@ -581,7 +580,7 @@ def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
581580
return self.src[item]
582581

583582

584-
def jit(func : Callable, out=None, **kwargs): # noqa: C901
583+
def jit(func=None, *, out=None, **kwargs): # noqa: C901
585584
"""
586585
Prepare a function so that it can be used with the Blosc2 compute engine.
587586
@@ -610,9 +609,10 @@ def jit(func : Callable, out=None, **kwargs): # noqa: C901
610609
-----
611610
* Although many NumPy functions are supported, some may not be implemented yet.
612611
If you find a function that is not supported, please open an issue.
613-
* `kwargs` parameters are not supported for all expressions (e.g. when using a
614-
reduction as the last function). In this case, you can still use the `out`
615-
parameter of the reduction function for some custom control over the output.
612+
* `out` and `kwargs` parameters are not supported for all expressions
613+
(e.g. when using a reduction as the last function). In this case, you can
614+
still use the `out` parameter of the reduction function for some custom
615+
control over the output.
616616
617617
Examples
618618
--------

tests/ndarray/test_jit.py

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#######################################################################
2+
# Copyright (c) 2019-present, Blosc Development Team <[email protected]>
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under a BSD-style license (found in the
6+
# LICENSE file in the root directory of this source tree)
7+
#######################################################################
8+
9+
import pytest
10+
11+
import blosc2
12+
13+
import numpy as np
14+
15+
###### General expressions
16+
17+
def expr_nojit(a, b, c):
18+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
19+
20+
@blosc2.jit
21+
def expr_jit(a, b, c):
22+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
23+
24+
# Define the parameters
25+
test_params = [
26+
((10, 100), (10, 100,), "float32"),
27+
((10, 100), (100,), "float64"), # using broadcasting
28+
]
29+
30+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
31+
def test_expr(shape, cshape, dtype):
32+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
33+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
34+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
35+
36+
d_jit = expr_jit(a, b, c)
37+
d_nojit = expr_nojit(a, b, c)
38+
39+
np.testing.assert_equal(d_jit[...], d_nojit[...])
40+
41+
42+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
43+
def test_expr_out(shape, cshape, dtype):
44+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
45+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
46+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
47+
d_nojit = expr_nojit(a, b, c)
48+
49+
# Testing jit decorator with an out param
50+
out = blosc2.zeros(shape, dtype=np.bool_)
51+
52+
@blosc2.jit(out=out)
53+
def expr_jit_out(a, b, c):
54+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
55+
56+
d_jit = expr_jit_out(a, b, c)
57+
np.testing.assert_equal(d_jit[...], d_nojit[...])
58+
np.testing.assert_equal(out[...], d_nojit[...])
59+
60+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
61+
def test_expr_kwargs(shape, cshape, dtype):
62+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
63+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
64+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
65+
d_nojit = expr_nojit(a, b, c)
66+
67+
# Testing jit decorator with kwargs
68+
cparams = blosc2.CParams(clevel=1, codec=blosc2.Codec.LZ4, filters=[blosc2.Filter.BITSHUFFLE])
69+
70+
@blosc2.jit(**{"cparams": cparams})
71+
def expr_jit_cparams(a, b, c):
72+
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)
73+
74+
d_jit = expr_jit_cparams(a, b, c)
75+
np.testing.assert_equal(d_jit[...], d_nojit[...])
76+
assert d_jit.schunk.cparams.clevel == 1
77+
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
78+
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5
79+
80+
81+
###### Reductions
82+
83+
def reduc_nojit(a, b, c):
84+
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
85+
86+
@blosc2.jit
87+
def reduc_jit(a, b, c):
88+
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
89+
90+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
91+
def test_reduc(shape, cshape, dtype):
92+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
93+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
94+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
95+
96+
d_jit = reduc_jit(a, b, c)
97+
d_nojit = reduc_nojit(a, b, c)
98+
99+
np.testing.assert_equal(d_jit[...], d_nojit[...])
100+
101+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
102+
def test_reduc_out(shape, cshape, dtype):
103+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
104+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
105+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
106+
d_nojit = reduc_nojit(a, b, c)
107+
108+
# Testing jit decorator with an out param via the reduction function
109+
out = np.zeros((shape[0],), dtype=np.int64)
110+
111+
# Note that out does not work with reductions as the last function call
112+
@blosc2.jit
113+
def reduc_jit_out(a, b, c):
114+
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)
115+
116+
d_jit = reduc_jit_out(a, b, c)
117+
np.testing.assert_equal(d_jit[...], d_nojit[...])
118+
np.testing.assert_equal(out[...], d_nojit[...])
119+
120+
@pytest.mark.parametrize("shape, cshape, dtype", test_params)
121+
def test_reduc_kwargs(shape, cshape, dtype):
122+
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
123+
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
124+
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
125+
d_nojit = reduc_nojit(a, b, c)
126+
127+
# Testing jit decorator with kwargs via an out param in the reduction function
128+
cparams = blosc2.CParams(clevel=1, codec=blosc2.Codec.LZ4, filters=[blosc2.Filter.BITSHUFFLE])
129+
out = blosc2.zeros((shape[0],), dtype=np.int64, cparams=cparams)
130+
131+
@blosc2.jit
132+
def reduc_jit_cparams(a, b, c):
133+
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)
134+
135+
d_jit = reduc_jit_cparams(a, b, c)
136+
np.testing.assert_equal(d_jit[...], d_nojit[...])
137+
assert d_jit.schunk.cparams.clevel == 1
138+
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
139+
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5

0 commit comments

Comments
 (0)