Collect optimizer plan members lazily (on first use) instead of up front,
drop the commented-out scalar-join checks in `resolve_join`, and add a test
comparing eager and lazy `join_where` runtimes.

diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs
index a7aa88058621..f75d3ae23d45 100644
--- a/crates/polars-plan/src/plans/conversion/join.rs
+++ b/crates/polars-plan/src/plans/conversion/join.rs
@@ -243,9 +243,6 @@ pub fn resolve_join(
     let mut as_with_columns_l = vec![];
     let mut as_with_columns_r = vec![];
     for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
-        //polars_ensure!(!lnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
-        //polars_ensure!(!rnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
-
         let ltype = get_dtype!(lnode, &schema_left)?;
         let rtype = get_dtype!(rnode, &schema_right)?;
 
diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs
index 2b498245981a..b384df2b3d6e 100644
--- a/crates/polars-plan/src/plans/optimizer/mod.rs
+++ b/crates/polars-plan/src/plans/optimizer/mod.rs
@@ -116,17 +116,26 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
     #[cfg(debug_assertions)]
     let prev_schema = lp_arena.get(lp_top).schema(lp_arena).into_owned();
 
-    // Collect members for optimizations that need it.
-    let mut members = MemberCollector::new();
-    if !opt_flags.eager() && (comm_subexpr_elim || opt_flags.projection_pushdown()) {
-        members.collect(lp_top, lp_arena, expr_arena)
+    let mut _opt_members = &mut None;
+
+    macro_rules! get_or_init_members {
+        () => {
+            _get_or_init_members(_opt_members, lp_top, lp_arena, expr_arena)
+        };
+    }
+
+    macro_rules! get_members_opt {
+        () => {
+            _opt_members.as_mut()
+        };
     }
 
     // Run before slice pushdown
-    if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE)
-        && members.has_group_by | members.has_sort | members.has_distinct
-    {
-        set_order_flags(lp_top, lp_arena, expr_arena, scratch);
+    if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE) {
+        let members = get_or_init_members!();
+        if members.has_group_by | members.has_sort | members.has_distinct {
+            set_order_flags(lp_top, lp_arena, expr_arena, scratch);
+        }
     }
 
     if opt_flags.simplify_expr() {
@@ -135,21 +144,24 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
     }
 
     #[cfg(feature = "cse")]
-    let _cse_plan_changed = if comm_subplan_elim
-        && members.has_joins_or_unions
-        && members.has_duplicate_scans()
-        && !members.has_cache
-    {
-        if verbose {
-            eprintln!("found multiple sources; run comm_subplan_elim")
-        }
-        let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);
+    let _cse_plan_changed = if comm_subplan_elim {
+        let members = get_or_init_members!();
+
+        if members.has_joins_or_unions && members.has_duplicate_scans() && !members.has_cache {
+            if verbose {
+                eprintln!("found multiple sources; run comm_subplan_elim")
+            }
 
-        prune_unused_caches(lp_arena, cid2c);
+            let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);
 
-        lp_top = lp;
-        members.has_cache |= changed;
-        changed
+            prune_unused_caches(lp_arena, cid2c);
+
+            lp_top = lp;
+            members.has_cache |= changed;
+            changed
+        } else {
+            false
+        }
     } else {
         false
     };
@@ -181,7 +193,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
     }
 
     // Make sure it is after predicate pushdown
-    if opt_flags.collapse_joins() && members.has_filter_with_join_input {
+    if opt_flags.collapse_joins() && get_or_init_members!().has_filter_with_join_input {
         collapse_joins::optimize(lp_top, lp_arena, expr_arena);
     }
 
@@ -219,7 +231,10 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
 
     lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?;
 
-    if members.has_joins_or_unions && members.has_cache && _cse_plan_changed {
+    if _cse_plan_changed
+        && get_members_opt!()
+            .is_some_and(|members| members.has_joins_or_unions && members.has_cache)
+    {
         // We only want to run this on cse inserted caches
         cache_states::set_cache_states(
             lp_top,
@@ -234,7 +249,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
 
     // This one should run (nearly) last as this modifies the projections
     #[cfg(feature = "cse")]
-    if comm_subexpr_elim && !members.has_ext_context {
+    if comm_subexpr_elim && !get_or_init_members!().has_ext_context {
         let mut optimizer = CommonSubExprOptimizer::new();
         let alp_node = IRNode::new_mutate(lp_top);
 
@@ -260,3 +275,22 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/
 
     Ok(lp_top)
 }
+
+/// Returns the cached [`MemberCollector`], collecting the plan members on first use.
+///
+/// When `opt_members` is `None`, a new collector is run over the plan rooted at
+/// `lp_top` and stored in it; later calls return that same collector without
+/// collecting again.
+fn _get_or_init_members<'a>(
+    opt_members: &'a mut Option<MemberCollector>,
+    lp_top: Node,
+    lp_arena: &mut Arena<IR>,
+    expr_arena: &mut Arena<AExpr>,
+) -> &'a mut MemberCollector {
+    opt_members.get_or_insert_with(|| {
+        let mut members = MemberCollector::new();
+        members.collect(lp_top, lp_arena, expr_arena);
+
+        members
+    })
+}
diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py
index 3e2bf381ed89..f0cd214529c1 100644
--- a/py-polars/tests/unit/operations/test_inequality_join.py
+++ b/py-polars/tests/unit/operations/test_inequality_join.py
@@ -663,7 +663,7 @@ def test_join_where_literal_20061() -> None:
         df_right,
         pl.col("value_left") > pl.col("value_right"),
         pl.col("flag_right") == pl.lit(1, dtype=pl.Int8),
-    ).sort("id").to_dict(as_series=False) == {
+    ).sort(pl.all()).to_dict(as_series=False) == {
         "id": [1, 2, 3, 3],
         "value_left": [10, 20, 30, 30],
         "flag": [1, 0, 1, 1],
diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py
index 1127c96f4f09..9c57a08f7611 100644
--- a/py-polars/tests/unit/operations/test_join.py
+++ b/py-polars/tests/unit/operations/test_join.py
@@ -3,7 +3,8 @@
 import typing
 import warnings
 from datetime import date, datetime
-from typing import TYPE_CHECKING, Literal
+from time import perf_counter
+from typing import TYPE_CHECKING, Any, Callable, Literal
 
 import numpy as np
 import pandas as pd
@@ -1738,3 +1739,28 @@ def test_empty_join_result_with_array_15474() -> None:
     result = lhs.join(rhs, on="x")
     expected = pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)})
     assert_frame_equal(result, expected)
+
+
+@pytest.mark.slow
+def test_join_where_eager_perf_21145() -> None:
+    """Eager `join_where` must be at most 1.3x slower than the lazy equivalent."""
+    left = pl.Series("left", range(3_000)).to_frame()
+    right = pl.Series("right", range(1_000)).to_frame()
+
+    def time_func(func: Callable[[], Any]) -> float:
+        times = []
+        for _ in range(3):
+            t = perf_counter()
+            func()
+            times.append(perf_counter() - t)
+
+        return min(times)
+
+    p = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right"))
+    runtime_eager = time_func(lambda: left.join_where(right, p))
+    runtime_lazy = time_func(lambda: left.lazy().join_where(right.lazy(), p).collect())
+    runtime_ratio = runtime_eager / runtime_lazy
+
+    if runtime_ratio > 1.3:
+        msg = f"runtime_ratio ({runtime_ratio}) > 1.3x ({runtime_eager = }, {runtime_lazy = })"
+        raise ValueError(msg)