Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 973e80c

Browse files
committed
use define_impl_rule_discriminant
Signed-off-by: Yuchen Liang <[email protected]>
1 parent fa66727 commit 973e80c

File tree

3 files changed

18 additions and 240 deletions

3 files changed

18 additions and 240 deletions

optd-datafusion-repr/src/lib.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ impl DatafusionOptimizer {
105105
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
106106
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
107107
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
108-
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
109-
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
110-
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
108+
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
111109
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
112110
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
113111
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
@@ -179,7 +177,7 @@ impl DatafusionOptimizer {
179177
for rule in rules {
180178
rule_wrappers.push(rule);
181179
}
182-
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
180+
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
183181
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
184182
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
185183
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));

optd-datafusion-repr/src/rules/joins.rs

Lines changed: 6 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use optd_core::nodes::PlanNodeOrGroup;
99
use optd_core::optimizer::Optimizer;
1010
use optd_core::rules::{Rule, RuleMatcher};
1111

12-
use super::macros::{define_impl_rule, define_rule};
12+
use super::macros::{define_impl_rule_discriminant, define_rule};
1313
use crate::plan_nodes::{
1414
ArcDfPlanNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType, DfNodeType,
1515
DfPredType, DfReprPlanNode, DfReprPredNode, JoinType, ListPred, LogOpType,
@@ -140,241 +140,14 @@ fn apply_join_assoc(
140140
vec![node.into_plan_node().into()]
141141
}
142142

143-
define_impl_rule!(
144-
HashJoinInnerRule,
145-
apply_hash_join_inner,
143+
// Note: `define_impl_rule_discriminant!` matches on the node-type discriminant, so this rule applies to every join type; the `JoinType::Inner` below is only a placeholder.
144+
define_impl_rule_discriminant!(
145+
HashJoinRule,
146+
apply_hash_join,
146147
(Join(JoinType::Inner), left, right)
147148
);
148149

149-
fn apply_hash_join_inner(
150-
optimizer: &impl Optimizer<DfNodeType>,
151-
binding: ArcDfPlanNode,
152-
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
153-
let join = LogicalJoin::from_plan_node(binding).unwrap();
154-
let cond = join.cond();
155-
let left = join.left();
156-
let right = join.right();
157-
let join_type = join.join_type();
158-
match cond.typ {
159-
DfPredType::BinOp(BinOpType::Eq) => {
160-
let left_schema = optimizer.get_schema_of(left.clone());
161-
let op = BinOpPred::from_pred_node(cond.clone()).unwrap();
162-
let left_expr = op.left_child();
163-
let right_expr = op.right_child();
164-
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
165-
return vec![];
166-
};
167-
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
168-
return vec![];
169-
};
170-
let can_convert = if left_expr.index() < left_schema.len()
171-
&& right_expr.index() >= left_schema.len()
172-
{
173-
true
174-
} else if right_expr.index() < left_schema.len()
175-
&& left_expr.index() >= left_schema.len()
176-
{
177-
(left_expr, right_expr) = (right_expr, left_expr);
178-
true
179-
} else {
180-
false
181-
};
182-
183-
if can_convert {
184-
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
185-
let node = PhysicalHashJoin::new_unchecked(
186-
left,
187-
right,
188-
ListPred::new(vec![left_expr.into_pred_node()]),
189-
ListPred::new(vec![right_expr.into_pred_node()]),
190-
*join_type,
191-
);
192-
return vec![node.into_plan_node().into()];
193-
}
194-
}
195-
DfPredType::LogOp(LogOpType::And) => {
196-
// currently only support consecutive equal queries
197-
let mut is_consecutive_eq = true;
198-
for child in cond.children.clone() {
199-
if let DfPredType::BinOp(BinOpType::Eq) = child.typ {
200-
continue;
201-
} else {
202-
is_consecutive_eq = false;
203-
break;
204-
}
205-
}
206-
if !is_consecutive_eq {
207-
return vec![];
208-
}
209-
210-
let left_schema = optimizer.get_schema_of(left.clone());
211-
let mut left_exprs = vec![];
212-
let mut right_exprs = vec![];
213-
for child in &cond.children {
214-
let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap();
215-
let left_expr = bin_op.left_child();
216-
let right_expr = bin_op.right_child();
217-
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
218-
return vec![];
219-
};
220-
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
221-
return vec![];
222-
};
223-
let can_convert = if left_expr.index() < left_schema.len()
224-
&& right_expr.index() >= left_schema.len()
225-
{
226-
true
227-
} else if right_expr.index() < left_schema.len()
228-
&& left_expr.index() >= left_schema.len()
229-
{
230-
(left_expr, right_expr) = (right_expr, left_expr);
231-
true
232-
} else {
233-
false
234-
};
235-
if !can_convert {
236-
return vec![];
237-
}
238-
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
239-
right_exprs.push(right_expr.into_pred_node());
240-
left_exprs.push(left_expr.into_pred_node());
241-
}
242-
243-
let node = PhysicalHashJoin::new_unchecked(
244-
left,
245-
right,
246-
ListPred::new(left_exprs),
247-
ListPred::new(right_exprs),
248-
*join_type,
249-
);
250-
return vec![node.into_plan_node().into()];
251-
}
252-
_ => {}
253-
}
254-
vec![]
255-
}
256-
257-
define_impl_rule!(
258-
HashJoinLeftOuterRule,
259-
apply_hash_join_left_outer,
260-
(Join(JoinType::LeftOuter), left, right)
261-
);
262-
263-
fn apply_hash_join_left_outer(
264-
optimizer: &impl Optimizer<DfNodeType>,
265-
binding: ArcDfPlanNode,
266-
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
267-
let join = LogicalJoin::from_plan_node(binding).unwrap();
268-
let cond = join.cond();
269-
let left = join.left();
270-
let right = join.right();
271-
let join_type = join.join_type();
272-
match cond.typ {
273-
DfPredType::BinOp(BinOpType::Eq) => {
274-
let left_schema = optimizer.get_schema_of(left.clone());
275-
let op = BinOpPred::from_pred_node(cond.clone()).unwrap();
276-
let left_expr = op.left_child();
277-
let right_expr = op.right_child();
278-
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
279-
return vec![];
280-
};
281-
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
282-
return vec![];
283-
};
284-
let can_convert = if left_expr.index() < left_schema.len()
285-
&& right_expr.index() >= left_schema.len()
286-
{
287-
true
288-
} else if right_expr.index() < left_schema.len()
289-
&& left_expr.index() >= left_schema.len()
290-
{
291-
(left_expr, right_expr) = (right_expr, left_expr);
292-
true
293-
} else {
294-
false
295-
};
296-
297-
if can_convert {
298-
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
299-
let node = PhysicalHashJoin::new_unchecked(
300-
left,
301-
right,
302-
ListPred::new(vec![left_expr.into_pred_node()]),
303-
ListPred::new(vec![right_expr.into_pred_node()]),
304-
*join_type,
305-
);
306-
return vec![node.into_plan_node().into()];
307-
}
308-
}
309-
DfPredType::LogOp(LogOpType::And) => {
310-
// currently only support consecutive equal queries
311-
let mut is_consecutive_eq = true;
312-
for child in cond.children.clone() {
313-
if let DfPredType::BinOp(BinOpType::Eq) = child.typ {
314-
continue;
315-
} else {
316-
is_consecutive_eq = false;
317-
break;
318-
}
319-
}
320-
if !is_consecutive_eq {
321-
return vec![];
322-
}
323-
324-
let left_schema = optimizer.get_schema_of(left.clone());
325-
let mut left_exprs = vec![];
326-
let mut right_exprs = vec![];
327-
for child in &cond.children {
328-
let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap();
329-
let left_expr = bin_op.left_child();
330-
let right_expr = bin_op.right_child();
331-
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
332-
return vec![];
333-
};
334-
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
335-
return vec![];
336-
};
337-
let can_convert = if left_expr.index() < left_schema.len()
338-
&& right_expr.index() >= left_schema.len()
339-
{
340-
true
341-
} else if right_expr.index() < left_schema.len()
342-
&& left_expr.index() >= left_schema.len()
343-
{
344-
(left_expr, right_expr) = (right_expr, left_expr);
345-
true
346-
} else {
347-
false
348-
};
349-
if !can_convert {
350-
return vec![];
351-
}
352-
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
353-
right_exprs.push(right_expr.into_pred_node());
354-
left_exprs.push(left_expr.into_pred_node());
355-
}
356-
357-
let node = PhysicalHashJoin::new_unchecked(
358-
left,
359-
right,
360-
ListPred::new(left_exprs),
361-
ListPred::new(right_exprs),
362-
*join_type,
363-
);
364-
return vec![node.into_plan_node().into()];
365-
}
366-
_ => {}
367-
}
368-
vec![]
369-
}
370-
371-
define_impl_rule!(
372-
HashJoinLeftMarkRule,
373-
apply_hash_join_left_mark,
374-
(Join(JoinType::LeftMark), left, right)
375-
);
376-
377-
fn apply_hash_join_left_mark(
150+
fn apply_hash_join(
378151
optimizer: &impl Optimizer<DfNodeType>,
379152
binding: ArcDfPlanNode,
380153
) -> Vec<PlanNodeOrGroup<DfNodeType>> {

optd-datafusion-repr/src/rules/macros.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,19 @@ macro_rules! define_rule_discriminant {
7979
};
8080
}
8181

82-
macro_rules! define_impl_rule {
82+
// macro_rules! define_impl_rule {
83+
// ($name:ident, $apply:ident, $($matcher:tt)+) => {
84+
// crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
85+
// };
86+
// }
87+
88+
macro_rules! define_impl_rule_discriminant {
8389
($name:ident, $apply:ident, $($matcher:tt)+) => {
84-
crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
90+
crate::rules::macros::define_rule_inner! { true, true, $name, $apply, $($matcher)+ }
8591
};
8692
}
8793

8894
pub(crate) use {
89-
define_impl_rule, define_matcher, define_rule, define_rule_discriminant, define_rule_inner,
95+
define_impl_rule_discriminant, define_matcher, define_rule, define_rule_discriminant,
96+
define_rule_inner,
9097
};

0 commit comments

Comments
 (0)