Skip to content

Commit e0a3bb5

Browse files
fix: Fix performance regression for eager join_where (#21308)
1 parent a438cc3 commit e0a3bb5

File tree

4 files changed

+80
-29
lines changed

4 files changed

+80
-29
lines changed

crates/polars-plan/src/plans/conversion/join.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,6 @@ pub fn resolve_join(
243243
let mut as_with_columns_l = vec![];
244244
let mut as_with_columns_r = vec![];
245245
for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
246-
//polars_ensure!(!lnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
247-
//polars_ensure!(!rnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
248-
249246
let ltype = get_dtype!(lnode, &schema_left)?;
250247
let rtype = get_dtype!(rnode, &schema_right)?;
251248

crates/polars-plan/src/plans/optimizer/mod.rs

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,26 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
116116
#[cfg(debug_assertions)]
117117
let prev_schema = lp_arena.get(lp_top).schema(lp_arena).into_owned();
118118

119-
// Collect members for optimizations that need it.
120-
let mut members = MemberCollector::new();
121-
if !opt_flags.eager() && (comm_subexpr_elim || opt_flags.projection_pushdown()) {
122-
members.collect(lp_top, lp_arena, expr_arena)
119+
let mut _opt_members = &mut None;
120+
121+
macro_rules! get_or_init_members {
122+
() => {
123+
_get_or_init_members(_opt_members, lp_top, lp_arena, expr_arena)
124+
};
125+
}
126+
127+
macro_rules! get_members_opt {
128+
() => {
129+
_opt_members.as_mut()
130+
};
123131
}
124132

125133
// Run before slice pushdown
126-
if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE)
127-
&& members.has_group_by | members.has_sort | members.has_distinct
128-
{
129-
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
134+
if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE) {
135+
let members = get_or_init_members!();
136+
if members.has_group_by | members.has_sort | members.has_distinct {
137+
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
138+
}
130139
}
131140

132141
if opt_flags.simplify_expr() {
@@ -135,21 +144,24 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
135144
}
136145

137146
#[cfg(feature = "cse")]
138-
let _cse_plan_changed = if comm_subplan_elim
139-
&& members.has_joins_or_unions
140-
&& members.has_duplicate_scans()
141-
&& !members.has_cache
142-
{
143-
if verbose {
144-
eprintln!("found multiple sources; run comm_subplan_elim")
145-
}
146-
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);
147+
let _cse_plan_changed = if comm_subplan_elim {
148+
let members = get_or_init_members!();
149+
150+
if members.has_joins_or_unions && members.has_duplicate_scans() && !members.has_cache {
151+
if verbose {
152+
eprintln!("found multiple sources; run comm_subplan_elim")
153+
}
147154

148-
prune_unused_caches(lp_arena, cid2c);
155+
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);
149156

150-
lp_top = lp;
151-
members.has_cache |= changed;
152-
changed
157+
prune_unused_caches(lp_arena, cid2c);
158+
159+
lp_top = lp;
160+
members.has_cache |= changed;
161+
changed
162+
} else {
163+
false
164+
}
153165
} else {
154166
false
155167
};
@@ -181,7 +193,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
181193
}
182194

183195
// Make sure it is after predicate pushdown
184-
if opt_flags.collapse_joins() && members.has_filter_with_join_input {
196+
if opt_flags.collapse_joins() && get_or_init_members!().has_filter_with_join_input {
185197
collapse_joins::optimize(lp_top, lp_arena, expr_arena);
186198
}
187199

@@ -219,7 +231,10 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
219231

220232
lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?;
221233

222-
if members.has_joins_or_unions && members.has_cache && _cse_plan_changed {
234+
if _cse_plan_changed
235+
&& get_members_opt!()
236+
.is_some_and(|members| members.has_joins_or_unions && members.has_cache)
237+
{
223238
// We only want to run this on cse inserted caches
224239
cache_states::set_cache_states(
225240
lp_top,
@@ -234,7 +249,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
234249

235250
// This one should run (nearly) last as this modifies the projections
236251
#[cfg(feature = "cse")]
237-
if comm_subexpr_elim && !members.has_ext_context {
252+
if comm_subexpr_elim && !get_or_init_members!().has_ext_context {
238253
let mut optimizer = CommonSubExprOptimizer::new();
239254
let alp_node = IRNode::new_mutate(lp_top);
240255

@@ -260,3 +275,17 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
260275

261276
Ok(lp_top)
262277
}
278+
279+
fn _get_or_init_members<'a>(
280+
opt_members: &'a mut Option<MemberCollector>,
281+
lp_top: Node,
282+
lp_arena: &mut Arena<IR>,
283+
expr_arena: &mut Arena<AExpr>,
284+
) -> &'a mut MemberCollector {
285+
opt_members.get_or_insert_with(|| {
286+
let mut members = MemberCollector::new();
287+
members.collect(lp_top, lp_arena, expr_arena);
288+
289+
members
290+
})
291+
}

py-polars/tests/unit/operations/test_inequality_join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def test_join_where_literal_20061() -> None:
663663
df_right,
664664
pl.col("value_left") > pl.col("value_right"),
665665
pl.col("flag_right") == pl.lit(1, dtype=pl.Int8),
666-
).sort("id").to_dict(as_series=False) == {
666+
).sort(pl.all()).to_dict(as_series=False) == {
667667
"id": [1, 2, 3, 3],
668668
"value_left": [10, 20, 30, 30],
669669
"flag": [1, 0, 1, 1],

py-polars/tests/unit/operations/test_join.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import typing
44
import warnings
55
from datetime import date, datetime
6-
from typing import TYPE_CHECKING, Literal
6+
from time import perf_counter
7+
from typing import TYPE_CHECKING, Any, Callable, Literal
78

89
import numpy as np
910
import pandas as pd
@@ -1738,3 +1739,27 @@ def test_empty_join_result_with_array_15474() -> None:
17381739
result = lhs.join(rhs, on="x")
17391740
expected = pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)})
17401741
assert_frame_equal(result, expected)
1742+
1743+
1744+
@pytest.mark.slow
def test_join_where_eager_perf_21145() -> None:
    """Check that eager `join_where` is at most 1.3x slower than the lazy path.

    Both variants run the same `is_between` predicate on the same two frames.
    Each is timed three times and the fastest run of each is compared.
    """
    left = pl.Series("left", range(3_000)).to_frame()
    right = pl.Series("right", range(1_000)).to_frame()

    def time_func(func: Callable[[], Any]) -> float:
        # Best-of-three wall-clock duration of `func`.
        def single_run() -> float:
            started = perf_counter()
            func()
            return perf_counter() - started

        return min(single_run() for _ in range(3))

    p = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))

    def run_eager() -> Any:
        return left.join_where(right, p)

    def run_lazy() -> Any:
        return left.lazy().join_where(right.lazy(), p).collect()

    # These three names appear verbatim in the error message via the
    # f-string `{name = }` form, so they must not be renamed.
    runtime_eager = time_func(run_eager)
    runtime_lazy = time_func(run_lazy)
    runtime_ratio = runtime_eager / runtime_lazy

    if runtime_ratio > 1.3:
        msg = f"runtime_ratio ({runtime_ratio}) > 1.3x ({runtime_eager = }, {runtime_lazy = })"
        raise ValueError(msg)

0 commit comments

Comments
 (0)