Skip to content

Commit b66ad84

Browse files
committed
Add where= filtered aggregate pushdown for CTable and arrays
1 parent 1179928 commit b66ad84

7 files changed

Lines changed: 343 additions & 34 deletions

File tree

examples/ctable/aggregates.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ class Reading:
5252
print(f"temperature min : {temp.min():.2f}")
5353
print(f"temperature max : {temp.max():.2f}")
5454

55+
# -- filtered aggregate pushdown -------------------------------------------
56+
# Use where= on aggregates to avoid materializing an intermediate filtered view.
57+
# Null sentinels are still skipped automatically.
58+
hot_sensor_3 = temp.sum(where=(t.sensor_id == 3) & (t.temperature > 25.0))
59+
hot_humidity = t["humidity"].mean(where=t.temperature > 25.0)
60+
print("\nFiltered aggregate pushdown:")
61+
print(f"sum temperature for sensor_id == 3 and temperature > 25 : {hot_sensor_3:.2f}")
62+
print(f"mean humidity when temperature > 25 : {hot_humidity:.2f}")
63+
5564
print(f"\nalert any : {t['alert'].any()}")
5665
print(f"alert all : {t['alert'].all()}")
5766

examples/ctable/querying.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ class Sale:
4040
print(f"Sales > $200: {len(high_value)} rows")
4141
print(high_value)
4242

43+
# -- filtered aggregate pushdown -------------------------------------------
44+
# For aggregate queries, pass the predicate directly with where= so Blosc2 can
45+
# avoid materializing the filtered table view.
46+
non_returned_revenue = t.amount.sum(where=~t.returned)
47+
north_revenue = t.amount.sum(where=(t.region == "North") & ~t.returned)
48+
print(f"Revenue for non-returned sales: ${non_returned_revenue:.2f}")
49+
print(f"Revenue for non-returned North sales: ${north_revenue:.2f}")
50+
4351
not_returned = t["not returned"]
4452
print(f"Not returned: {len(not_returned)} rows")
4553

src/blosc2/ctable.py

Lines changed: 183 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,12 @@ def _valid_rows(self):
549549

550550
return (self._table._valid_rows & self._mask).compute()
551551

552+
def _lazy_valid_rows(self):
    """Return this column's visible-row mask without forcing lazy evaluation.

    The result is the parent table's ``_valid_rows`` mask, combined with this
    column's own ``_mask`` via ``&`` when one is set.  No ``.compute()`` is
    called here, so the caller can keep composing the returned object into a
    larger expression.
    """
    table_rows = self._table._valid_rows
    if self._mask is None:
        return table_rows
    return table_rows & self._mask
557+
552558
def __getitem__(self, key: int | slice | list | np.ndarray):
553559
"""Return values for the given logical index.
554560
@@ -1241,7 +1247,74 @@ def _require_kind(self, kinds: str, op: str) -> None:
12411247
# Aggregates
12421248
# ------------------------------------------------------------------
12431249

1244-
def sum(self, dtype=None):
1250+
def _normalize_sum_where(self, where):
    """Normalize an optional ``where=`` predicate to a boolean array/expression.

    Despite the name, this is the shared normalizer for every aggregate that
    accepts ``where=`` (``sum``, ``min``, ``max``, ``mean`` and ``std``).

    Accepted inputs:

    * ``None`` -- returned unchanged (no predicate).
    * ``str`` -- checked with the table's varlen-scalar guard, then turned
      into a lazy expression over the table's where-expression operands.
    * boolean ``np.ndarray`` -- wrapped with ``blosc2.asarray``.
    * ``Column`` -- replaced by its raw column; a nullable-bool column is
      compared against ``1`` first.
    * boolean ``blosc2.NDArray`` / ``blosc2.LazyExpr`` -- passed through.

    Raises
    ------
    TypeError
        If the normalized predicate is not a boolean ``blosc2.NDArray`` or
        ``blosc2.LazyExpr``.
    """
    if where is None:
        return None

    if isinstance(where, str):
        self._table._guard_varlen_scalar_expression(where)
        where = blosc2.lazyexpr(where, self._table._where_expression_operands())
    elif isinstance(where, np.ndarray) and where.dtype == np.bool_:
        where = blosc2.asarray(where)
    elif isinstance(where, Column):
        where = where._raw_col == 1 if where._is_nullable_bool else where._raw_col

    is_blosc2_operand = isinstance(where, (blosc2.NDArray, blosc2.LazyExpr))
    found_dtype = getattr(where, "dtype", None)
    if not (is_blosc2_operand and found_dtype == np.bool_):
        # Report the dtype too: a non-boolean NDArray would otherwise produce
        # the confusing "expected ... NDArray ..., got NDArray".
        detail = type(where).__name__
        if found_dtype is not None:
            detail = f"{detail} with dtype {found_dtype}"
        raise TypeError(f"Expected boolean blosc2.NDArray or LazyExpr, got {detail}")
    return where
1267+
1268+
def _lazy_nonnull_mask(self, where=None):
    """Build the lazy row mask selecting the values an aggregate should see.

    The mask starts from this column's visible rows (``_lazy_valid_rows``).
    When *where* is given it is AND-ed in, and when the column has a null
    sentinel (``null_value`` is not ``None``) a "value is not the sentinel"
    test is AND-ed in as well.

    Returns
    -------
    The combined mask, or ``NotImplemented`` when the raw column is not a
    ``blosc2.NDArray`` / ``blosc2.LazyExpr``.
    """
    raw = self._raw_col
    if not isinstance(raw, (blosc2.NDArray, blosc2.LazyExpr)):
        return NotImplemented

    mask = self._lazy_valid_rows()
    if where is not None:
        mask = mask & where

    sentinel = self.null_value
    if sentinel is None:
        return mask

    # NaN never compares equal to itself, so a NaN sentinel needs isnan()
    # rather than an inequality test.
    if isinstance(sentinel, (float, np.floating)) and np.isnan(sentinel):
        not_null = ~blosc2.isnan(raw)
    else:
        not_null = raw != sentinel
    return mask & not_null
1284+
1285+
def _sum_lazy_fastpath(self, acc_dtype, where=None):
    """Try to compute ``sum`` as a pushed-down lazy masked reduction.

    Parameters
    ----------
    acc_dtype:
        Callable scalar type used both to build the zero fill value and as
        the accumulator ``dtype`` of the reduction.
    where:
        Optional, already-normalized boolean predicate.

    Returns
    -------
    The reduction result, or ``NotImplemented`` when the fast path does not
    apply (list / varlen-scalar column, missing or non-numeric dtype,
    non-blosc2 storage, sparse filtered view, or a failing lazy evaluation).
    The caller is expected to fall back to another strategy in that case.
    """
    if self.is_list or self.is_varlen_scalar or self.dtype is None or self.dtype.kind not in "biufc":
        return NotImplemented

    raw = self._raw_col
    if not isinstance(raw, (blosc2.NDArray, blosc2.LazyExpr)):
        return NotImplemented

    # A lazy masked reduction scans the full physical column. For very
    # selective filtered views, the existing iterator can skip all-zero mask
    # chunks and is usually faster. Explicit sum(where=...) is already a
    # direct pushed-down aggregate, so do not apply the density guard there.
    if where is None and self._table.base is not None:
        # Only look up the physical row count when the guard can actually fire.
        total_rows = len(self._table._valid_rows)
        if total_rows and self._table._n_rows / total_rows < 0.25:
            return NotImplemented

    mask = self._lazy_nonnull_mask(where=where)
    if mask is NotImplemented:
        return NotImplemented

    # Built outside the try so a bad accumulator type is not swallowed below.
    zero = acc_dtype(0)
    try:
        # Masked-out rows contribute zero, so no filtered copy is needed.
        return blosc2.where(mask, raw, zero).sum(dtype=acc_dtype)
    except Exception:
        # Deliberate best-effort: any failure of the lazy evaluation means
        # "fast path unavailable", not an error for the caller.
        return NotImplemented
1316+
1317+
def sum(self, dtype=None, *, where=None):
12451318
"""Sum of all live, non-null values.
12461319
12471320
Supported dtypes: bool, int, uint, float, complex.
@@ -1254,9 +1327,32 @@ def sum(self, dtype=None):
12541327
Optional accumulator dtype. When omitted, float columns use
12551328
``np.float64``, complex columns use ``np.complex128``, and integer
12561329
/ bool columns use ``np.int64``.
1330+
where:
1331+
Optional boolean predicate. Only rows where the predicate is true,
1332+
the table row is live, and this column is non-null are included.
1333+
This enables direct filtered aggregate pushdown, avoiding creation
1334+
of an intermediate filtered table view.
1335+
1336+
Examples
1337+
--------
1338+
Sum values matching a predicate without materializing a filtered view::
1339+
1340+
total = t["amount"].sum(where=t.category == 3)
1341+
1342+
Combine several column predicates::
1343+
1344+
total = t.col2.sum(where=(t.col1 < 300) & (t.col2 < 400))
1345+
1346+
Nullable sentinel values are skipped automatically::
1347+
1348+
# Equivalent to summing only live rows where predicate is true and
1349+
# t.col2 is not its configured null sentinel.
1350+
total = t.col2.sum(where=t.col1 < 300)
12571351
"""
12581352
self._require_kind("biufc", "sum")
1259-
self._require_nonempty("sum")
1353+
where = self._normalize_sum_where(where)
1354+
if where is None:
1355+
self._require_nonempty("sum")
12601356
# Use a wide accumulator to reduce overflow risk
12611357
acc_dtype = np.dtype(dtype).type if dtype is not None else None
12621358
if acc_dtype is None:
@@ -1271,23 +1367,68 @@ def sum(self, dtype=None):
12711367
else None
12721368
)
12731369
)
1274-
result = acc_dtype(0)
1275-
for chunk in self._nonnull_chunks():
1276-
result += chunk.sum(dtype=acc_dtype)
1370+
1371+
result = self._sum_lazy_fastpath(acc_dtype, where=where)
1372+
if result is NotImplemented:
1373+
if where is not None:
1374+
return self._table.where(where)[self._col_name].sum(dtype=dtype)
1375+
result = acc_dtype(0)
1376+
for chunk in self._nonnull_chunks():
1377+
result += chunk.sum(dtype=acc_dtype)
1378+
12771379
# Return in the column's natural dtype when it fits, else keep the requested/wide dtype
12781380
if dtype is None and self.dtype.kind in "biu":
12791381
return int(result)
12801382
return result
12811383

1282-
def min(self):
1384+
def _lazy_aggregate_fastpath(self, op: str, *, where=None, dtype=None, ddof: int = 0):
    """Try to compute ``min``/``max``/``mean``/``std`` as a lazy masked reduction.

    Parameters
    ----------
    op:
        One of ``"min"``, ``"max"``, ``"mean"`` or ``"std"``.  Any other
        value yields ``NotImplemented``.
    where:
        Optional, already-normalized boolean predicate.
    dtype:
        Accumulator dtype for ``mean``/``std``; ``np.float64`` when omitted.
    ddof:
        Delta degrees of freedom, forwarded to ``std``.

    Returns
    -------
    The aggregate (a Python float for ``mean``/``std``), ``float("nan")`` when
    a ``mean``/``std`` reduction raises ``ValueError``, or ``NotImplemented``
    when the fast path does not apply or the lazy evaluation fails for any
    other reason.

    Raises
    ------
    ValueError
        For ``min``/``max`` when the mask selects no rows.
    """
    if self.is_list or self.is_varlen_scalar or self.dtype is None or self.dtype.kind not in "biuf":
        return NotImplemented
    raw = self._raw_col
    if not isinstance(raw, (blosc2.NDArray, blosc2.LazyExpr)):
        return NotImplemented
    mask = self._lazy_nonnull_mask(where=where)
    if mask is NotImplemented:
        return NotImplemented

    # Test against None rather than truthiness, so an explicitly passed dtype
    # object that evaluates as false is not silently replaced by float64.
    acc_dtype = np.float64 if dtype is None else dtype

    try:
        if op in {"min", "max"}:
            # Count the selected rows first so an empty selection raises a
            # clear error before the reduction is attempted.
            selected = int(mask.where(blosc2.ones(raw.shape, dtype=np.int64), 0).sum(dtype=np.int64))
            if selected == 0:
                raise ValueError(f"{op}() called on a column where all values are null.")
            if op == "min":
                return raw.min(where=mask)
            return raw.max(where=mask)
        if op == "mean":
            return float(raw.mean(where=mask, dtype=acc_dtype))
        if op == "std":
            return float(raw.std(where=mask, dtype=acc_dtype, ddof=ddof))
    except ValueError:
        # mean/std report NaN; min/max propagate the error (including the
        # "all values are null" one raised above).
        if op in {"mean", "std"}:
            return float("nan")
        raise
    except Exception:
        # Any other failure means "fast path unavailable".
        return NotImplemented
    return NotImplemented
1414+
1415+
def min(self, *, where=None):
12831416
"""Minimum live, non-null value.
12841417
12851418
Supported dtypes: bool, int, uint, float, string, bytes.
12861419
Strings are compared lexicographically.
1287-
Null sentinel values are skipped.
1420+
Null sentinel values are skipped. When *where* is provided, only rows
1421+
matching the boolean predicate are included.
12881422
"""
12891423
self._require_kind("biufUS", "min")
1290-
self._require_nonempty("min")
1424+
where = self._normalize_sum_where(where)
1425+
if where is None:
1426+
self._require_nonempty("min")
1427+
fast = self._lazy_aggregate_fastpath("min", where=where)
1428+
if fast is not NotImplemented:
1429+
return fast
1430+
if where is not None:
1431+
return self._table.where(where)[self._col_name].min()
12911432
result = None
12921433
is_str = self.dtype.kind in "US"
12931434
for chunk in self._nonnull_chunks():
@@ -1300,15 +1441,23 @@ def min(self):
13001441
raise ValueError("min() called on a column where all values are null.")
13011442
return result
13021443

1303-
def max(self):
1444+
def max(self, *, where=None):
13041445
"""Maximum live, non-null value.
13051446
13061447
Supported dtypes: bool, int, uint, float, string, bytes.
13071448
Strings are compared lexicographically.
1308-
Null sentinel values are skipped.
1449+
Null sentinel values are skipped. When *where* is provided, only rows
1450+
matching the boolean predicate are included.
13091451
"""
13101452
self._require_kind("biufUS", "max")
1311-
self._require_nonempty("max")
1453+
where = self._normalize_sum_where(where)
1454+
if where is None:
1455+
self._require_nonempty("max")
1456+
fast = self._lazy_aggregate_fastpath("max", where=where)
1457+
if fast is not NotImplemented:
1458+
return fast
1459+
if where is not None:
1460+
return self._table.where(where)[self._col_name].max()
13121461
result = None
13131462
is_str = self.dtype.kind in "US"
13141463
for chunk in self._nonnull_chunks():
@@ -1319,15 +1468,23 @@ def max(self):
13191468
raise ValueError("max() called on a column where all values are null.")
13201469
return result
13211470

1322-
def mean(self) -> float:
1471+
def mean(self, *, where=None) -> float:
13231472
"""Arithmetic mean of all live, non-null values.
13241473
13251474
Supported dtypes: bool, int, uint, float.
1326-
Null sentinel values are skipped.
1475+
Null sentinel values are skipped. When *where* is provided, only rows
1476+
matching the boolean predicate are included.
13271477
Always returns a Python float.
13281478
"""
13291479
self._require_kind("biuf", "mean")
1330-
self._require_nonempty("mean")
1480+
where = self._normalize_sum_where(where)
1481+
if where is None:
1482+
self._require_nonempty("mean")
1483+
fast = self._lazy_aggregate_fastpath("mean", where=where)
1484+
if fast is not NotImplemented:
1485+
return fast
1486+
if where is not None:
1487+
return self._table.where(where)[self._col_name].mean()
13311488
total = np.float64(0)
13321489
count = 0
13331490
for chunk in self._nonnull_chunks():
@@ -1337,21 +1494,31 @@ def mean(self) -> float:
13371494
return float("nan")
13381495
return float(total / count)
13391496

1340-
def std(self, ddof: int = 0) -> float:
1497+
def std(self, ddof: int = 0, *, where=None) -> float:
13411498
"""Standard deviation of all live, non-null values (single-pass, Welford's algorithm).
13421499
13431500
Parameters
13441501
----------
13451502
ddof:
13461503
Delta degrees of freedom. ``0`` (default) gives the population
13471504
std; ``1`` gives the sample std (divides by N-1).
1505+
where:
1506+
Optional boolean predicate. Only rows where the predicate is true,
1507+
the table row is live, and this column is non-null are included.
13481508
13491509
Supported dtypes: bool, int, uint, float.
13501510
Null sentinel values are skipped.
13511511
Always returns a Python float.
13521512
"""
13531513
self._require_kind("biuf", "std")
1354-
self._require_nonempty("std")
1514+
where = self._normalize_sum_where(where)
1515+
if where is None:
1516+
self._require_nonempty("std")
1517+
fast = self._lazy_aggregate_fastpath("std", where=where, ddof=ddof)
1518+
if fast is not NotImplemented:
1519+
return fast
1520+
if where is not None:
1521+
return self._table.where(where)[self._col_name].std(ddof=ddof)
13551522

13561523
# Chan's parallel update — combines per-chunk (n, mean, M2) tuples.
13571524
# This is numerically stable and requires only a single pass.

0 commit comments

Comments
 (0)