Skip to content

Commit bcb9f0d

Browse files
committed
Add optimization param to compute functions
1 parent d60bd7e commit bcb9f0d

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

src/blosc2/lazyexpr.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,7 @@ def fast_eval( # noqa: C901
897897
The output array.
898898
"""
899899
out = kwargs.pop("_output", None)
900+
optimization = kwargs.pop("_optimization", "aggressive")
900901
dtype = kwargs.pop("dtype", None)
901902
where: dict | None = kwargs.pop("_where_args", None)
902903
if isinstance(out, blosc2.NDArray):
@@ -962,12 +963,12 @@ def fast_eval( # noqa: C901
962963
expression(tuple(chunk_operands.values()), result, offset=offset)
963964
else:
964965
if where is None:
965-
result = ne.evaluate(expression, chunk_operands)
966+
result = ne.evaluate(expression, chunk_operands, optimization=optimization)
966967
else:
967968
# Apply the where condition (in result)
968969
if len(where) == 2:
969970
new_expr = f"where({expression}, _where_x, _where_y)"
970-
result = ne.evaluate(new_expr, chunk_operands)
971+
result = ne.evaluate(new_expr, chunk_operands, optimization=optimization)
971972
else:
972973
# We do not support one or zero operands in the fast path yet
973974
raise ValueError("Fast path: the where condition must be a tuple with two elements")
@@ -1068,6 +1069,7 @@ def slices_eval( # noqa: C901
10681069
The output array.
10691070
"""
10701071
out: blosc2.NDArray | None = kwargs.pop("_output", None)
1072+
optimization = kwargs.pop("_optimization", "aggressive")
10711073
chunks = kwargs.get("chunks")
10721074
where: dict | None = kwargs.pop("_where_args", None)
10731075
_indices = kwargs.pop("_indices", False)
@@ -1186,7 +1188,7 @@ def slices_eval( # noqa: C901
11861188
continue
11871189

11881190
if where is None:
1189-
result = ne.evaluate(expression, chunk_operands)
1191+
result = ne.evaluate(expression, chunk_operands, optimization=optimization)
11901192
else:
11911193
# Apply the where condition (in result)
11921194
if len(where) == 2:
@@ -1195,9 +1197,9 @@ def slices_eval( # noqa: C901
11951197
# result = np.where(result, x, y)
11961198
# numexpr is a bit faster than np.where, and we can fuse operations in this case
11971199
new_expr = f"where({expression}, _where_x, _where_y)"
1198-
result = ne.evaluate(new_expr, chunk_operands)
1200+
result = ne.evaluate(new_expr, chunk_operands, optimization=optimization)
11991201
elif len(where) == 1:
1200-
result = ne.evaluate(expression, chunk_operands)
1202+
result = ne.evaluate(expression, chunk_operands, optimization=optimization)
12011203
if _indices or _order:
12021204
# Return indices only makes sense when the where condition is a tuple with one element
12031205
# and result is a boolean array
@@ -1332,6 +1334,7 @@ def reduce_slices( # noqa: C901
13321334
The resulting output array.
13331335
"""
13341336
out = kwargs.pop("_output", None)
1337+
optimization = kwargs.pop("_optimization", "aggressive")
13351338
where: dict | None = kwargs.pop("_where_args", None)
13361339
reduce_op = reduce_args.pop("op")
13371340
axis = reduce_args["axis"]
@@ -1468,14 +1471,14 @@ def reduce_slices( # noqa: C901
14681471
# We don't have an actual expression, so avoid a copy
14691472
result = chunk_operands["o0"]
14701473
else:
1471-
result = ne.evaluate(expression, chunk_operands)
1474+
result = ne.evaluate(expression, chunk_operands, optimization=optimization)
14721475
else:
14731476
# Apply the where condition (in result)
14741477
if len(where) == 2:
14751478
new_expr = f"where({expression}, _where_x, _where_y)"
1476-
result = ne.evaluate(new_expr, chunk_operands)
1479+
result = ne.evaluate(new_expr, chunk_operands, optimization=optimization)
14771480
elif len(where) == 1:
1478-
result = ne.evaluate(expression, chunk_operands)
1481+
result = ne.evaluate(expression, chunk_operands, optimization=optimization)
14791482
x = chunk_operands["_where_x"]
14801483
result = x[result]
14811484
else:
@@ -1579,6 +1582,8 @@ def chunked_eval( # noqa: C901
15791582
Default is False.
15801583
_output: NDArray or np.ndarray, optional
15811584
The output array to store the result.
1585+
_optimization: str, optional
1586+
The optimization level to use. Default is 'aggressive'.
15821587
_where_args: dict, optional
15831588
Additional arguments for conditional evaluation.
15841589
"""
@@ -2399,6 +2404,8 @@ def compute(self, item=None, **kwargs) -> blosc2.NDArray:
23992404
# When NumPy ufuncs are called, the user may add an `out` parameter to kwargs
24002405
if "out" in kwargs:
24012406
kwargs["_output"] = kwargs.pop("out")
2407+
if "optimization" in kwargs:
2408+
kwargs["_optimization"] = kwargs.pop("optimization")
24022409
if hasattr(self, "_output"):
24032410
kwargs["_output"] = self._output
24042411
if hasattr(self, "_where_args"):
@@ -2538,7 +2545,7 @@ def save(self, urlpath=None, **kwargs):
25382545
}
25392546

25402547
@classmethod
2541-
def _new_expr(cls, expression, operands, guess, out=None, where=None):
2548+
def _new_expr(cls, expression, operands, guess, out=None, where=None, optimization="aggressive"):
25422549
# Validate the expression
25432550
validate_expr(expression)
25442551
if guess:
@@ -2578,6 +2585,7 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None):
25782585
new_expr._output = out
25792586
if where is not None:
25802587
new_expr._where_args = where
2588+
new_expr._optimization = optimization
25812589
return new_expr
25822590

25832591

@@ -2864,6 +2872,7 @@ def lazyexpr(
28642872
where: tuple | list | None = None,
28652873
local_dict: dict | None = None,
28662874
global_dict: dict | None = None,
2875+
optimization: str = "aggressive",
28672876
) -> LazyExpr:
28682877
"""
28692878
Get a LazyExpr from an expression.
@@ -2889,6 +2898,10 @@ def lazyexpr(
28892898
global_dict: dict, optional
28902899
The global dictionary to use when looking for operands in the expression.
28912900
If not provided, the global dictionary of the caller will be used.
2901+
optimization: str, optional
2902+
The optimization level to use when evaluating the expression. Possible
2903+
values are "aggressive" and "moderate". The default value is "aggressive".
2904+
This parameter has the same meaning as in `numexpr.evaluate()`.
28922905
28932906
Returns
28942907
-------
@@ -2925,13 +2938,16 @@ def lazyexpr(
29252938
expression.operands.update(operands)
29262939
if out is not None:
29272940
expression._output = out
2941+
expression._optimization = optimization
29282942
if where is not None:
29292943
where_args = {"_where_x": where[0], "_where_y": where[1]}
29302944
expression._where_args = where_args
29312945
return expression
29322946
elif isinstance(expression, blosc2.NDArray):
29332947
operands = {"o0": expression}
2934-
return LazyExpr._new_expr("o0", operands, guess=False, out=out, where=where)
2948+
return LazyExpr._new_expr(
2949+
"o0", operands, guess=False, out=out, where=where, optimization=optimization
2950+
)
29352951

29362952
if operands is None:
29372953
# Try to get operands from variables in the stack
@@ -2948,7 +2964,9 @@ def lazyexpr(
29482964
# _new_expr will take care of the constructor, but needs an empty dict in operands
29492965
operands = {}
29502966

2951-
return LazyExpr._new_expr(expression, operands, guess=True, out=out, where=where)
2967+
return LazyExpr._new_expr(
2968+
expression, operands, guess=True, out=out, where=where, optimization=optimization
2969+
)
29522970

29532971

29542972
def _open_lazyarray(array):
@@ -2985,7 +3003,7 @@ def _open_lazyarray(array):
29853003

29863004

29873005
# Mimim numexpr's evaluate function
2988-
def evaluate(ex, local_dict=None, global_dict=None):
3006+
def evaluate(ex, local_dict=None, global_dict=None, optimization="aggressive"):
29893007
"""
29903008
Evaluate a string expression using the Blosc2 compute engine.
29913009
@@ -3008,6 +3026,10 @@ def evaluate(ex, local_dict=None, global_dict=None):
30083026
global_dict: dict, optional
30093027
The global dictionary to use when looking for operands in the expression.
30103028
If not provided, the global dictionary of the caller will be used.
3029+
optimization: str, optional
3030+
The optimization level to use when evaluating the expression. Possible
3031+
values are "aggressive" and "moderate". The default value is "aggressive".
3032+
This parameter has the same meaning as in `numexpr.evaluate()`.
30113033
30123034
Returns
30133035
-------
@@ -3030,7 +3052,7 @@ def evaluate(ex, local_dict=None, global_dict=None):
30303052
[ 5.515625 8.25 11.765625]
30313053
[16.0625 21.140625 27. ]]
30323054
"""
3033-
lexpr = lazyexpr(ex, local_dict=local_dict, global_dict=global_dict)
3055+
lexpr = lazyexpr(ex, local_dict=local_dict, global_dict=global_dict, optimization=optimization)
30343056
return lexpr[:]
30353057

30363058

0 commit comments

Comments
 (0)