Skip to content

Commit

Permalink
fix: Fix incorrect predicate pushdown for predicates referring to rig…
Browse files Browse the repository at this point in the history
…ht-join key columns (#21293)
  • Loading branch information
nameexhaustion authored Feb 17, 2025
1 parent 8978e18 commit a438cc3
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 158 deletions.
5 changes: 2 additions & 3 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;

use once_cell::sync::Lazy;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
Expand Down Expand Up @@ -138,8 +137,8 @@ impl JoinArgs {
}

pub fn suffix(&self) -> &PlSmallStr {
static DEFAULT: Lazy<PlSmallStr> = Lazy::new(|| PlSmallStr::from_static("_right"));
self.suffix.as_ref().unwrap_or(&*DEFAULT)
const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
self.suffix.as_ref().unwrap_or(DEFAULT)
}
}

Expand Down
91 changes: 12 additions & 79 deletions crates/polars-plan/src/plans/optimizer/collapse_joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,84 +10,11 @@ use polars_core::schema::*;
use polars_ops::frame::{IEJoinOptions, InequalityOperator};
use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin};
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;

use super::{aexpr_to_leaf_names_iter, AExpr, ExprOrigin, JoinOptions, IR};
use crate::dsl::{JoinTypeOptionsIR, Operator};
use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker};
use crate::plans::{ExprIR, MintermIter, OutputName};

fn remove_suffix<'a>(
exprs: &mut Vec<ExprIR>,
expr_arena: &mut Arena<AExpr>,
schema: &'a SchemaRef,
suffix: &'a str,
) {
let mut remover = RemoveSuffix {
schema: schema.as_ref(),
suffix,
};

for expr in exprs {
// Using AexprNode::rewrite() ensures we do not mutate any nodes in-place. The nodes may be
// used in other locations and mutating them will cause really confusing bugs, such as
// https://github.com/pola-rs/polars/issues/20831.
match AexprNode::new(expr.node()).rewrite(&mut remover, expr_arena) {
Ok(v) => {
expr.set_node(v.node());

if let OutputName::ColumnLhs(colname) = expr.output_name_inner() {
if colname.ends_with(suffix) && !schema.contains(colname.as_str()) {
let name = PlSmallStr::from(&colname[..colname.len() - suffix.len()]);
expr.set_columnlhs(name);
}
}
},
e @ Err(_) => panic!("should not have failed: {:?}", e),
}
}
}

struct RemoveSuffix<'a> {
schema: &'a Schema,
suffix: &'a str,
}

impl RewritingVisitor for RemoveSuffix<'_> {
type Node = AexprNode;
type Arena = Arena<AExpr>;

fn pre_visit(
&mut self,
node: &Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<crate::prelude::visitor::RewriteRecursion> {
let AExpr::Column(colname) = arena.get(node.node()) else {
return Ok(RewriteRecursion::NoMutateAndContinue);
};

if !colname.ends_with(self.suffix) || self.schema.contains(colname.as_str()) {
return Ok(RewriteRecursion::NoMutateAndContinue);
}

Ok(RewriteRecursion::MutateAndContinue)
}

fn mutate(
&mut self,
node: Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<Self::Node> {
let AExpr::Column(colname) = arena.get(node.node()) else {
unreachable!();
};

// Safety: Checked in pre_visit()
Ok(AexprNode::new(arena.add(AExpr::Column(PlSmallStr::from(
&colname[..colname.len() - self.suffix.len()],
)))))
}
}
use crate::plans::optimizer::join_utils::remove_suffix;
use crate::plans::{ExprIR, MintermIter};

fn and_expr(left: Node, right: Node, expr_arena: &mut Arena<AExpr>) -> Node {
expr_arena.add(AExpr::BinaryExpr {
Expand Down Expand Up @@ -195,14 +122,16 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &mut Arena<AEx
left_schema,
right_schema,
suffix.as_str(),
);
)
.unwrap();
let right_origin = ExprOrigin::get_expr_origin(
right,
expr_arena,
left_schema,
right_schema,
suffix.as_str(),
);
)
.unwrap();

use ExprOrigin as EO;

Expand Down Expand Up @@ -282,12 +211,16 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &mut Arena<AEx
let mut can_simplify_join = false;

if !eq_left_on.is_empty() {
remove_suffix(&mut eq_right_on, expr_arena, right_schema, suffix.as_str());
for expr in eq_right_on.iter_mut() {
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
}
can_simplify_join = true;
} else {
#[cfg(feature = "iejoin")]
if !ie_op.is_empty() {
remove_suffix(&mut ie_right_on, expr_arena, right_schema, suffix.as_str());
for expr in ie_right_on.iter_mut() {
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
}
can_simplify_join = true;
}
can_simplify_join |= options.args.how.is_cross();
Expand Down
145 changes: 109 additions & 36 deletions crates/polars-plan/src/plans/optimizer/join_utils.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use polars_core::error::{polars_bail, PolarsResult};
use polars_core::schema::*;
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;

use super::{aexpr_to_leaf_names_iter, AExpr};
use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker};
use crate::plans::{ExprIR, OutputName};

/// Join origin of an expression
#[derive(Debug, Clone, PartialEq, Copy)]
#[repr(u8)]
pub(crate) enum ExprOrigin {
// Note: There is a merge() function implemented on this enum that relies
// on this exact u8 repr layout.
// Note: BitOr is implemented on this struct that relies on this exact u8
// repr layout (i.e. treated as a bitfield).
//
/// Utilizes no columns
None = 0b00,
Expand All @@ -21,52 +25,121 @@ pub(crate) enum ExprOrigin {
}

impl ExprOrigin {
/// Errors with ColumnNotFound if a column cannot be found on either side.
pub(crate) fn get_expr_origin(
root: Node,
expr_arena: &Arena<AExpr>,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
left_schema: &Schema,
right_schema: &Schema,
suffix: &str,
) -> ExprOrigin {
let mut expr_origin = ExprOrigin::None;

for name in aexpr_to_leaf_names_iter(root, expr_arena) {
let in_left = left_schema.contains(name.as_str());
let in_right = right_schema.contains(name.as_str());
let has_suffix = name.as_str().ends_with(suffix);
let in_right = in_right
| (has_suffix
&& right_schema.contains(&name.as_str()[..name.len() - suffix.len()]));

let name_origin = match (in_left, in_right, has_suffix) {
(true, false, _) | (true, true, false) => ExprOrigin::Left,
(false, true, _) | (true, true, true) => ExprOrigin::Right,
(false, false, _) => {
unreachable!("Invalid filter column should have been filtered before")
},
};

use ExprOrigin as O;
expr_origin = match (expr_origin, name_origin) {
(O::None, other) | (other, O::None) => other,
(O::Left, O::Left) => O::Left,
(O::Right, O::Right) => O::Right,
_ => O::Both,
};
}
) -> PolarsResult<ExprOrigin> {
aexpr_to_leaf_names_iter(root, expr_arena).try_fold(
ExprOrigin::None,
|acc_origin, column_name| {
Ok(acc_origin
| Self::get_column_origin(&column_name, left_schema, right_schema, suffix)?)
},
)
}

expr_origin
/// Errors with ColumnNotFound if a column cannot be found on either side.
pub(crate) fn get_column_origin(
column_name: &str,
left_schema: &Schema,
right_schema: &Schema,
suffix: &str,
) -> PolarsResult<ExprOrigin> {
Ok(if left_schema.contains(column_name) {
ExprOrigin::Left
} else if right_schema.contains(column_name)
|| column_name
.strip_suffix(suffix)
.is_some_and(|x| right_schema.contains(x))
{
ExprOrigin::Right
} else {
polars_bail!(ColumnNotFound: "{}", column_name)
})
}
}

/// Logical OR with another [`ExprOrigin`]
fn merge(&mut self, other: Self) {
*self = unsafe { std::mem::transmute::<u8, ExprOrigin>(*self as u8 | other as u8) }
impl std::ops::BitOr for ExprOrigin {
type Output = ExprOrigin;

fn bitor(self, rhs: Self) -> Self::Output {
unsafe { std::mem::transmute::<u8, ExprOrigin>(self as u8 | rhs as u8) }
}
}

impl std::ops::BitOrAssign for ExprOrigin {
fn bitor_assign(&mut self, rhs: Self) {
self.merge(rhs)
*self = *self | rhs;
}
}

pub(super) fn remove_suffix<'a>(
expr: &mut ExprIR,
expr_arena: &mut Arena<AExpr>,
schema_rhs: &'a Schema,
suffix: &'a str,
) {
let schema = schema_rhs;
// Using AexprNode::rewrite() ensures we do not mutate any nodes in-place. The nodes may be
// used in other locations and mutating them will cause really confusing bugs, such as
// https://github.com/pola-rs/polars/issues/20831.
let node = AexprNode::new(expr.node())
.rewrite(&mut RemoveSuffix { schema, suffix }, expr_arena)
.unwrap()
.node();

expr.set_node(node);

if let OutputName::ColumnLhs(colname) = expr.output_name_inner() {
if colname.ends_with(suffix) && !schema.contains(colname.as_str()) {
let name = PlSmallStr::from(&colname[..colname.len() - suffix.len()]);
expr.set_columnlhs(name);
}
}

struct RemoveSuffix<'a> {
schema: &'a Schema,
suffix: &'a str,
}

impl RewritingVisitor for RemoveSuffix<'_> {
type Node = AexprNode;
type Arena = Arena<AExpr>;

fn pre_visit(
&mut self,
node: &Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<crate::prelude::visitor::RewriteRecursion> {
let AExpr::Column(colname) = arena.get(node.node()) else {
return Ok(RewriteRecursion::NoMutateAndContinue);
};

if !colname.ends_with(self.suffix) || self.schema.contains(colname.as_str()) {
return Ok(RewriteRecursion::NoMutateAndContinue);
}

Ok(RewriteRecursion::MutateAndContinue)
}

fn mutate(
&mut self,
node: Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<Self::Node> {
let AExpr::Column(colname) = arena.get(node.node()) else {
unreachable!();
};

// Safety: Checked in pre_visit()
Ok(AexprNode::new(arena.add(AExpr::Column(PlSmallStr::from(
&colname[..colname.len() - self.suffix.len()],
)))))
}
}
}

Expand Down
Loading

0 comments on commit a438cc3

Please sign in to comment.