Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
support left outer and left mark join for hash join rule
Browse files Browse the repository at this point in the history
Not ideal: the inner, left-outer, and left-mark cases should eventually be united into one rule.

Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Dec 22, 2024
1 parent d1e27a5 commit cecc8a3
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 7 deletions.
2 changes: 2 additions & 0 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ impl OptdPlanContext<'_> {
let right_exec = self.conv_from_optd_plan_node(node.right(), meta).await?;
let join_type = match node.join_type() {
JoinType::Inner => datafusion::logical_expr::JoinType::Inner,
JoinType::LeftOuter => datafusion::logical_expr::JoinType::Left,
JoinType::LeftMark => datafusion::logical_expr::JoinType::LeftMark,
_ => unimplemented!(),
};
let left_exprs = node.left_keys().to_vec();
Expand Down
6 changes: 4 additions & 2 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ impl DatafusionOptimizer {
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down Expand Up @@ -177,7 +179,7 @@ impl DatafusionOptimizer {
for rule in rules {
rule_wrappers.push(rule);
}
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down
239 changes: 234 additions & 5 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,20 @@ fn apply_join_assoc(
}

define_impl_rule!(
HashJoinRule,
apply_hash_join,
HashJoinInnerRule,
apply_hash_join_inner,
(Join(JoinType::Inner), left, right)
);

fn apply_hash_join(
fn apply_hash_join_inner(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
let join = LogicalJoin::from_plan_node(binding).unwrap();
let cond = join.cond();
let left = join.left();
let right = join.right();
let join_type = join.join_type();
match cond.typ {
DfPredType::BinOp(BinOpType::Eq) => {
let left_schema = optimizer.get_schema_of(left.clone());
Expand Down Expand Up @@ -186,7 +187,7 @@ fn apply_hash_join(
right,
ListPred::new(vec![left_expr.into_pred_node()]),
ListPred::new(vec![right_expr.into_pred_node()]),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down Expand Up @@ -244,7 +245,235 @@ fn apply_hash_join(
right,
ListPred::new(left_exprs),
ListPred::new(right_exprs),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
_ => {}
}
vec![]
}

// Declares `HashJoinLeftOuterRule` through `define_impl_rule!`, associating the
// pattern `(Join(JoinType::LeftOuter), left, right)` with the function
// `apply_hash_join_left_outer` defined below.
// NOTE(review): the precise matching/binding semantics come from the macro
// definition, which is not visible in this file section — confirm there.
define_impl_rule!(
HashJoinLeftOuterRule,
apply_hash_join_left_outer,
(Join(JoinType::LeftOuter), left, right)
);

/// Apply function named in the `HashJoinLeftOuterRule` declaration.
///
/// Reads the join condition of the bound logical join and, when it is either a
/// single `=` comparison or an `And` whose children are all `=` comparisons,
/// produces one `PhysicalHashJoin` carrying the same join type as the input.
///
/// Every comparison must have a `ColumnRefPred` on both sides, with one column
/// index below the left schema's length and the other at or above it. The
/// lower-indexed column becomes a left key; the other is rebased by
/// subtracting the left schema's length and becomes a right key.
///
/// Returns an empty vector when the condition has any other shape.
///
/// # Panics
///
/// Panics if `LogicalJoin::from_plan_node` fails for `binding`.
fn apply_hash_join_left_outer(
    optimizer: &impl Optimizer<DfNodeType>,
    binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
    let logical_join = LogicalJoin::from_plan_node(binding).unwrap();
    let predicate = logical_join.cond();
    let lhs = logical_join.left();
    let rhs = logical_join.right();
    let join_type = logical_join.join_type();

    // Normalise the condition into a borrowed list of `=` comparisons so a
    // lone equality and a conjunction of equalities share one code path.
    let equalities = match predicate.typ {
        DfPredType::BinOp(BinOpType::Eq) => std::slice::from_ref(&predicate),
        DfPredType::LogOp(LogOpType::And) => {
            // Only conjunctions made up purely of equalities are handled.
            let only_equalities = predicate
                .children
                .iter()
                .all(|conjunct| matches!(conjunct.typ, DfPredType::BinOp(BinOpType::Eq)));
            if !only_equalities {
                return vec![];
            }
            predicate.children.as_slice()
        }
        _ => return vec![],
    };

    let left_width = optimizer.get_schema_of(lhs.clone()).len();
    let mut left_keys = Vec::with_capacity(equalities.len());
    let mut right_keys = Vec::with_capacity(equalities.len());

    for equality in equalities {
        let comparison = BinOpPred::from_pred_node(equality.clone()).unwrap();
        let (Some(first), Some(second)) = (
            ColumnRefPred::from_pred_node(comparison.left_child()),
            ColumnRefPred::from_pred_node(comparison.right_child()),
        ) else {
            return vec![];
        };

        // Orient the pair as (left-input column, right-input column); give up
        // if both columns fall on the same side of the left schema boundary.
        let (left_col, right_col) =
            if first.index() < left_width && second.index() >= left_width {
                (first, second)
            } else if second.index() < left_width && first.index() >= left_width {
                (second, first)
            } else {
                return vec![];
            };

        left_keys.push(left_col.into_pred_node());
        // Right keys are expressed relative to the right input's own schema.
        right_keys.push(ColumnRefPred::new(right_col.index() - left_width).into_pred_node());
    }

    let hash_join = PhysicalHashJoin::new_unchecked(
        lhs,
        rhs,
        ListPred::new(left_keys),
        ListPred::new(right_keys),
        *join_type,
    );
    vec![hash_join.into_plan_node().into()]
}

// Declares `HashJoinLeftMarkRule` through `define_impl_rule!`, associating the
// pattern `(Join(JoinType::LeftMark), left, right)` with the function
// `apply_hash_join_left_mark` defined below.
// NOTE(review): the precise matching/binding semantics come from the macro
// definition, which is not visible in this file section — confirm there.
define_impl_rule!(
HashJoinLeftMarkRule,
apply_hash_join_left_mark,
(Join(JoinType::LeftMark), left, right)
);

fn apply_hash_join_left_mark(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
let join = LogicalJoin::from_plan_node(binding).unwrap();
let cond = join.cond();
let left = join.left();
let right = join.right();
let join_type = join.join_type();
match cond.typ {
DfPredType::BinOp(BinOpType::Eq) => {
let left_schema = optimizer.get_schema_of(left.clone());
let op = BinOpPred::from_pred_node(cond.clone()).unwrap();
let left_expr = op.left_child();
let right_expr = op.right_child();
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
return vec![];
};
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
return vec![];
};
let can_convert = if left_expr.index() < left_schema.len()
&& right_expr.index() >= left_schema.len()
{
true
} else if right_expr.index() < left_schema.len()
&& left_expr.index() >= left_schema.len()
{
(left_expr, right_expr) = (right_expr, left_expr);
true
} else {
false
};

if can_convert {
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
let node = PhysicalHashJoin::new_unchecked(
left,
right,
ListPred::new(vec![left_expr.into_pred_node()]),
ListPred::new(vec![right_expr.into_pred_node()]),
*join_type,
);
return vec![node.into_plan_node().into()];
}
}
DfPredType::LogOp(LogOpType::And) => {
// currently only support consecutive equal queries
let mut is_consecutive_eq = true;
for child in cond.children.clone() {
if let DfPredType::BinOp(BinOpType::Eq) = child.typ {
continue;
} else {
is_consecutive_eq = false;
break;
}
}
if !is_consecutive_eq {
return vec![];
}

let left_schema = optimizer.get_schema_of(left.clone());
let mut left_exprs = vec![];
let mut right_exprs = vec![];
for child in &cond.children {
let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap();
let left_expr = bin_op.left_child();
let right_expr = bin_op.right_child();
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
return vec![];
};
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
return vec![];
};
let can_convert = if left_expr.index() < left_schema.len()
&& right_expr.index() >= left_schema.len()
{
true
} else if right_expr.index() < left_schema.len()
&& left_expr.index() >= left_schema.len()
{
(left_expr, right_expr) = (right_expr, left_expr);
true
} else {
false
};
if !can_convert {
return vec![];
}
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
right_exprs.push(right_expr.into_pred_node());
left_exprs.push(left_expr.into_pred_node());
}

let node = PhysicalHashJoin::new_unchecked(
left,
right,
ListPred::new(left_exprs),
ListPred::new(right_exprs),
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down

0 comments on commit cecc8a3

Please sign in to comment.