Skip to content

Commit

Permalink
fix: Fix performance regression for eager join_where (#21308)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Feb 18, 2025
1 parent a438cc3 commit e0a3bb5
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 29 deletions.
3 changes: 0 additions & 3 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,6 @@ pub fn resolve_join(
let mut as_with_columns_l = vec![];
let mut as_with_columns_r = vec![];
for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
//polars_ensure!(!lnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
//polars_ensure!(!rnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");

let ltype = get_dtype!(lnode, &schema_left)?;
let rtype = get_dtype!(rnode, &schema_right)?;

Expand Down
77 changes: 53 additions & 24 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,26 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
#[cfg(debug_assertions)]
let prev_schema = lp_arena.get(lp_top).schema(lp_arena).into_owned();

// Collect members for optimizations that need it.
let mut members = MemberCollector::new();
if !opt_flags.eager() && (comm_subexpr_elim || opt_flags.projection_pushdown()) {
members.collect(lp_top, lp_arena, expr_arena)
let mut _opt_members = &mut None;

macro_rules! get_or_init_members {
() => {
_get_or_init_members(_opt_members, lp_top, lp_arena, expr_arena)
};
}

macro_rules! get_members_opt {
() => {
_opt_members.as_mut()
};
}

// Run before slice pushdown
if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE)
&& members.has_group_by | members.has_sort | members.has_distinct
{
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE) {
let members = get_or_init_members!();
if members.has_group_by | members.has_sort | members.has_distinct {
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
}
}

if opt_flags.simplify_expr() {
Expand All @@ -135,21 +144,24 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
}

#[cfg(feature = "cse")]
let _cse_plan_changed = if comm_subplan_elim
&& members.has_joins_or_unions
&& members.has_duplicate_scans()
&& !members.has_cache
{
if verbose {
eprintln!("found multiple sources; run comm_subplan_elim")
}
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);
let _cse_plan_changed = if comm_subplan_elim {
let members = get_or_init_members!();

if members.has_joins_or_unions && members.has_duplicate_scans() && !members.has_cache {
if verbose {
eprintln!("found multiple sources; run comm_subplan_elim")
}

prune_unused_caches(lp_arena, cid2c);
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);

lp_top = lp;
members.has_cache |= changed;
changed
prune_unused_caches(lp_arena, cid2c);

lp_top = lp;
members.has_cache |= changed;
changed
} else {
false
}
} else {
false
};
Expand Down Expand Up @@ -181,7 +193,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
}

// Make sure it is after predicate pushdown
if opt_flags.collapse_joins() && members.has_filter_with_join_input {
if opt_flags.collapse_joins() && get_or_init_members!().has_filter_with_join_input {
collapse_joins::optimize(lp_top, lp_arena, expr_arena);
}

Expand Down Expand Up @@ -219,7 +231,10 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/

lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?;

if members.has_joins_or_unions && members.has_cache && _cse_plan_changed {
if _cse_plan_changed
&& get_members_opt!()
.is_some_and(|members| members.has_joins_or_unions && members.has_cache)
{
// We only want to run this on cse inserted caches
cache_states::set_cache_states(
lp_top,
Expand All @@ -234,7 +249,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/

// This one should run (nearly) last as this modifies the projections
#[cfg(feature = "cse")]
if comm_subexpr_elim && !members.has_ext_context {
if comm_subexpr_elim && !get_or_init_members!().has_ext_context {
let mut optimizer = CommonSubExprOptimizer::new();
let alp_node = IRNode::new_mutate(lp_top);

Expand All @@ -260,3 +275,17 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/

Ok(lp_top)
}

fn _get_or_init_members<'a>(
opt_members: &'a mut Option<MemberCollector>,
lp_top: Node,
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
) -> &'a mut MemberCollector {
opt_members.get_or_insert_with(|| {
let mut members = MemberCollector::new();
members.collect(lp_top, lp_arena, expr_arena);

members
})
}
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def test_join_where_literal_20061() -> None:
df_right,
pl.col("value_left") > pl.col("value_right"),
pl.col("flag_right") == pl.lit(1, dtype=pl.Int8),
).sort("id").to_dict(as_series=False) == {
).sort(pl.all()).to_dict(as_series=False) == {
"id": [1, 2, 3, 3],
"value_left": [10, 20, 30, 30],
"flag": [1, 0, 1, 1],
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import typing
import warnings
from datetime import date, datetime
from typing import TYPE_CHECKING, Literal
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1738,3 +1739,27 @@ def test_empty_join_result_with_array_15474() -> None:
result = lhs.join(rhs, on="x")
expected = pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)})
assert_frame_equal(result, expected)


@pytest.mark.slow
def test_join_where_eager_perf_21145() -> None:
left = pl.Series("left", range(3_000)).to_frame()
right = pl.Series("right", range(1_000)).to_frame()

def time_func(func: Callable[[], Any]) -> float:
times = []
for _ in range(3):
t = perf_counter()
func()
times.append(perf_counter() - t)

return min(times)

p = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))
runtime_eager = time_func(lambda: left.join_where(right, p))
runtime_lazy = time_func(lambda: left.lazy().join_where(right.lazy(), p).collect())
runtime_ratio = runtime_eager / runtime_lazy

if runtime_ratio > 1.3:
msg = f"runtime_ratio ({runtime_ratio}) > 1.3x ({runtime_eager = }, {runtime_lazy = })"
raise ValueError(msg)

0 comments on commit e0a3bb5

Please sign in to comment.