Skip to content

Commit e4c40e8

Browse files
committed
Support out in mean, std and var reductions too
1 parent 9c32763 commit e4c40e8

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

Diff for: src/blosc2/lazyexpr.py

+12
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,10 @@ def mean(self, axis=None, dtype=None, keepdims=False, **kwargs):
21882188
if num_elements == 0:
21892189
raise ValueError("mean of an empty array is not defined")
21902190
out = total_sum / num_elements
2191+
out2 = kwargs.pop("out", None)
2192+
if out2 is not None:
2193+
out2[:] = out
2194+
return out2
21912195
if kwargs != {} and not np.isscalar(out):
21922196
out = blosc2.asarray(out, **kwargs)
21932197
return out
@@ -2202,6 +2206,10 @@ def std(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
22022206
out = np.sqrt(out * num_elements / (num_elements - ddof))
22032207
else:
22042208
out = np.sqrt(out)
2209+
out2 = kwargs.pop("out", None)
2210+
if out2 is not None:
2211+
out2[:] = out
2212+
return out2
22052213
if kwargs != {} and not np.isscalar(out):
22062214
out = blosc2.asarray(out, **kwargs)
22072215
return out
@@ -2216,6 +2224,10 @@ def var(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
22162224
out = out * num_elements / (num_elements - ddof)
22172225
else:
22182226
out = expr.mean(axis=axis, dtype=dtype, keepdims=keepdims, item=item)
2227+
out2 = kwargs.pop("out", None)
2228+
if out2 is not None:
2229+
out2[:] = out
2230+
return out2
22192231
if kwargs != {} and not np.isscalar(out):
22202232
out = blosc2.asarray(out, **kwargs)
22212233
return out

Diff for: tests/ndarray/test_jit.py

+43
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ def reduc_nojit(a, b, c):
9494
return np.sum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
9595

9696

97+
def reduc_mean_nojit(a, b, c):
98+
return np.mean(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
99+
100+
101+
def reduc_std_nojit(a, b, c):
102+
return np.std(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
103+
104+
97105
@blosc2.jit
98106
def reduc_jit(a, b, c):
99107
return np.sum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1)
@@ -125,6 +133,22 @@ def reduc_jit_out(a, b, c):
125133
np.testing.assert_equal(out[...], d_nojit[...])
126134

127135

136+
def test_reduc_mean_out(sample_data):
137+
a, b, c, shape, cshape, dtype = sample_data
138+
d_nojit = reduc_mean_nojit(a, b, c)
139+
140+
# Testing jit decorator with an out param via the reduction function
141+
out = np.zeros((shape[0],), dtype=np.float64)
142+
143+
# Note that out does not work with reductions as the last function call
144+
@blosc2.jit
145+
def reduc_mean_jit_out(a, b, c):
146+
return np.mean(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)
147+
148+
d_jit = reduc_mean_jit_out(a, b, c)
149+
np.testing.assert_equal(out[...], d_nojit[...])
150+
151+
128152
def test_reduc_kwargs(sample_data):
129153
a, b, c, shape, cshape, dtype = sample_data
130154
d_nojit = reduc_nojit(a, b, c)
@@ -142,3 +166,22 @@ def reduc_jit_cparams(a, b, c):
142166
assert d_jit.schunk.cparams.clevel == 1
143167
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
144168
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5
169+
170+
171+
def test_reduc_std_kwargs(sample_data):
172+
a, b, c, shape, cshape, dtype = sample_data
173+
d_nojit = reduc_std_nojit(a, b, c)
174+
175+
# Testing jit decorator with kwargs via an out param in the reduction function
176+
cparams = blosc2.CParams(clevel=1, codec=blosc2.Codec.LZ4, filters=[blosc2.Filter.BITSHUFFLE])
177+
out = blosc2.zeros((shape[0],), dtype=np.float64, cparams=cparams)
178+
179+
@blosc2.jit
180+
def reduc_std_jit_cparams(a, b, c):
181+
return np.std(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)
182+
183+
d_jit = reduc_std_jit_cparams(a, b, c)
184+
np.testing.assert_equal(d_jit[...], d_nojit[...])
185+
assert d_jit.schunk.cparams.clevel == 1
186+
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
187+
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5

0 commit comments

Comments
 (0)