Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 3d01877

Browse files
authored
feat: support type expr (#73)
`DataTypeExpr` wraps `arrow_schema::DataType`.
1 parent e9188f9 commit 3d01877

File tree

8 files changed: 98 additions and 65 deletions.

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

optd-datafusion-bridge/src/from_optd.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use optd_datafusion_repr::{
2929
PhysicalSort, PlanNode, SortOrderExpr, SortOrderType,
3030
},
3131
properties::schema::Schema as OptdSchema,
32-
PhysicalCollector, Value,
32+
PhysicalCollector,
3333
};
3434

3535
use crate::{physical_collector::CollectorExec, OptdPlanContext};
@@ -251,14 +251,12 @@ impl OptdPlanContext<'_> {
251251
OptRelNodeTyp::Cast => {
252252
let expr = CastExpr::from_rel_node(expr.into_rel_node()).unwrap();
253253
let child = Self::conv_from_optd_expr(expr.child(), context)?;
254-
let data_type = match expr.cast_to() {
255-
Value::Bool(_) => DataType::Boolean,
256-
Value::Decimal128(_) => DataType::Decimal128(15, 2), /* TODO: AVOID HARD CODE PRECISION */
257-
Value::Date32(_) => DataType::Date32,
258-
other => unimplemented!("{}", other),
259-
};
260254
Ok(Arc::new(
261-
datafusion::physical_plan::expressions::CastExpr::new(child, data_type, None),
255+
datafusion::physical_plan::expressions::CastExpr::new(
256+
child,
257+
expr.cast_to(),
258+
None,
259+
),
262260
))
263261
}
264262
OptRelNodeTyp::Like => {

optd-datafusion-bridge/src/into_optd.rs

+6-24
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@ use datafusion::{
66
};
77
use datafusion_expr::Expr as DFExpr;
88
use optd_core::rel_node::RelNode;
9-
use optd_datafusion_repr::{
10-
plan_nodes::{
11-
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, Expr, ExprList,
12-
FuncExpr, FuncType, JoinType, LikeExpr, LogOpExpr, LogOpType, LogicalAgg,
13-
LogicalEmptyRelation, LogicalFilter, LogicalJoin, LogicalLimit, LogicalProjection,
14-
LogicalScan, LogicalSort, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode,
15-
SortOrderExpr, SortOrderType,
16-
},
17-
Value,
9+
use optd_datafusion_repr::plan_nodes::{
10+
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, Expr, ExprList,
11+
FuncExpr, FuncType, JoinType, LikeExpr, LogOpExpr, LogOpType, LogicalAgg, LogicalEmptyRelation,
12+
LogicalFilter, LogicalJoin, LogicalLimit, LogicalProjection, LogicalScan, LogicalSort,
13+
OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode, SortOrderExpr, SortOrderType,
1814
};
1915

2016
use crate::OptdPlanContext;
@@ -171,21 +167,7 @@ impl OptdPlanContext<'_> {
171167
}
172168
Expr::Cast(x) => {
173169
let expr = self.conv_into_optd_expr(x.expr.as_ref(), context)?;
174-
let data_type = x.data_type.clone();
175-
let val = match data_type {
176-
arrow_schema::DataType::Int8 => Value::Int8(0),
177-
arrow_schema::DataType::Int16 => Value::Int16(0),
178-
arrow_schema::DataType::Int32 => Value::Int32(0),
179-
arrow_schema::DataType::Int64 => Value::Int64(0),
180-
arrow_schema::DataType::UInt8 => Value::UInt8(0),
181-
arrow_schema::DataType::UInt16 => Value::UInt16(0),
182-
arrow_schema::DataType::UInt32 => Value::UInt32(0),
183-
arrow_schema::DataType::UInt64 => Value::UInt64(0),
184-
arrow_schema::DataType::Date32 => Value::Date32(0),
185-
arrow_schema::DataType::Decimal128(_, _) => Value::Decimal128(0),
186-
other => unimplemented!("unimplemented datatype {:?}", other),
187-
};
188-
Ok(CastExpr::new(expr, val).into_expr())
170+
Ok(CastExpr::new(expr, x.data_type.clone()).into_expr())
189171
}
190172
Expr::Like(x) => {
191173
let expr = self.conv_into_optd_expr(x.expr.as_ref(), context)?;

optd-datafusion-repr/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = "2021"
77

88
[dependencies]
99
anyhow = "1"
10+
arrow-schema = "47.0.0"
1011
num-traits = "0.2"
1112
num-derive = "0.2"
1213
tracing = "0.1"

optd-datafusion-repr/src/plan_nodes.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ mod sort;
1414

1515
use std::sync::Arc;
1616

17+
use arrow_schema::DataType;
1718
use optd_core::{
1819
cascades::{CascadesOptimizer, GroupId},
1920
rel_node::{RelNode, RelNodeRef, RelNodeTyp},
@@ -24,8 +25,8 @@ pub use apply::{ApplyType, LogicalApply};
2425
pub use empty_relation::{LogicalEmptyRelation, PhysicalEmptyRelation};
2526
pub use expr::{
2627
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, ConstantType,
27-
ExprList, FuncExpr, FuncType, LikeExpr, LogOpExpr, LogOpType, SortOrderExpr, SortOrderType,
28-
UnOpExpr, UnOpType,
28+
DataTypeExpr, ExprList, FuncExpr, FuncType, LikeExpr, LogOpExpr, LogOpType, SortOrderExpr,
29+
SortOrderType, UnOpExpr, UnOpType,
2930
};
3031
pub use filter::{LogicalFilter, PhysicalFilter};
3132
pub use join::{JoinType, LogicalJoin, PhysicalHashJoin, PhysicalNestedLoopJoin};
@@ -77,6 +78,7 @@ pub enum OptRelNodeTyp {
7778
Between,
7879
Cast,
7980
Like,
81+
DataType(DataType),
8082
}
8183

8284
impl OptRelNodeTyp {
@@ -118,6 +120,7 @@ impl OptRelNodeTyp {
118120
| Self::Between
119121
| Self::Cast
120122
| Self::Like
123+
| Self::DataType(_)
121124
)
122125
}
123126
}
@@ -387,6 +390,9 @@ pub fn explain(rel_node: OptRelNodeRef) -> Pretty<'static> {
387390
OptRelNodeTyp::Like => LikeExpr::from_rel_node(rel_node)
388391
.unwrap()
389392
.dispatch_explain(),
393+
OptRelNodeTyp::DataType(_) => DataTypeExpr::from_rel_node(rel_node)
394+
.unwrap()
395+
.dispatch_explain(),
390396
}
391397
}
392398

optd-datafusion-repr/src/plan_nodes/expr.rs

+53-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{fmt::Display, sync::Arc};
22

3+
use arrow_schema::DataType;
34
use itertools::Itertools;
45
use pretty_xmlish::Pretty;
56

@@ -636,19 +637,60 @@ impl OptRelNode for BetweenExpr {
636637
}
637638
}
638639

640+
#[derive(Clone, Debug)]
641+
pub struct DataTypeExpr(pub Expr);
642+
643+
impl DataTypeExpr {
644+
pub fn new(typ: DataType) -> Self {
645+
DataTypeExpr(Expr(
646+
RelNode {
647+
typ: OptRelNodeTyp::DataType(typ),
648+
children: vec![],
649+
data: None,
650+
}
651+
.into(),
652+
))
653+
}
654+
655+
pub fn data_type(&self) -> DataType {
656+
if let OptRelNodeTyp::DataType(data_type) = self.0.typ() {
657+
data_type
658+
} else {
659+
panic!("not a data type")
660+
}
661+
}
662+
}
663+
664+
impl OptRelNode for DataTypeExpr {
665+
fn into_rel_node(self) -> OptRelNodeRef {
666+
self.0.into_rel_node()
667+
}
668+
669+
fn from_rel_node(rel_node: OptRelNodeRef) -> Option<Self> {
670+
if !matches!(rel_node.typ, OptRelNodeTyp::DataType(_)) {
671+
return None;
672+
}
673+
Expr::from_rel_node(rel_node).map(Self)
674+
}
675+
676+
fn dispatch_explain(&self) -> Pretty<'static> {
677+
Pretty::display(&self.data_type().to_string())
678+
}
679+
}
680+
639681
#[derive(Clone, Debug)]
640682
pub struct CastExpr(pub Expr);
641683

642684
impl CastExpr {
643-
pub fn new(
644-
expr: Expr,
645-
cast_to: Value, /* TODO: have a `type` relnode for representing type */
646-
) -> Self {
685+
pub fn new(expr: Expr, cast_to: DataType) -> Self {
647686
CastExpr(Expr(
648687
RelNode {
649688
typ: OptRelNodeTyp::Cast,
650-
children: vec![expr.into_rel_node()],
651-
data: Some(cast_to),
689+
children: vec![
690+
expr.into_rel_node(),
691+
DataTypeExpr::new(cast_to).into_rel_node(),
692+
],
693+
data: None,
652694
}
653695
.into(),
654696
))
@@ -658,8 +700,10 @@ impl CastExpr {
658700
Expr(self.0.child(0))
659701
}
660702

661-
pub fn cast_to(&self) -> Value {
662-
self.0 .0.data.clone().unwrap()
703+
pub fn cast_to(&self) -> DataType {
704+
DataTypeExpr::from_rel_node(self.0.child(1))
705+
.unwrap()
706+
.data_type()
663707
}
664708
}
665709

@@ -679,7 +723,7 @@ impl OptRelNode for CastExpr {
679723
Pretty::simple_record(
680724
"Cast",
681725
vec![
682-
("cast_to", format!("{:?}", self.cast_to()).into()),
726+
("cast_to", format!("{}", self.cast_to()).into()),
683727
("expr", self.child().explain()),
684728
],
685729
vec![],

optd-datafusion-repr/src/properties/column_ref.rs

+1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ impl PropertyBuilder<OptRelNodeTyp> for ColumnRefPropertyBuilder {
106106
OptRelNodeTyp::Constant(_)
107107
| OptRelNodeTyp::Func(_)
108108
| OptRelNodeTyp::BinOp(_)
109+
| OptRelNodeTyp::DataType(_)
109110
| OptRelNodeTyp::Between
110111
| OptRelNodeTyp::EmptyRelation
111112
| OptRelNodeTyp::Like => {

optd-sqlplannertest/tests/tpch.planner.sql

+22-22
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ LogicalSort
124124
│ └── Mul
125125
│ ├── #22
126126
│ └── Sub
127-
│ ├── Cast { cast_to: Decimal128(0), expr: 1 }
127+
│ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
128128
│ └── #23
129129
├── groups: [ #41 ]
130130
└── LogicalFilter
@@ -159,10 +159,10 @@ LogicalSort
159159
│ │ │ └── "Asia"
160160
│ │ └── Geq
161161
│ │ ├── #12
162-
│ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
162+
│ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
163163
│ └── Lt
164164
│ ├── #12
165-
│ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
165+
│ └── Cast { cast_to: Date32, expr: "2024-01-01" }
166166
└── LogicalJoin { join_type: Cross, cond: true }
167167
├── LogicalJoin { join_type: Cross, cond: true }
168168
│ ├── LogicalJoin { join_type: Cross, cond: true }
@@ -183,7 +183,7 @@ PhysicalSort
183183
│ └── Mul
184184
│ ├── #22
185185
│ └── Sub
186-
│ ├── Cast { cast_to: Decimal128(0), expr: 1 }
186+
│ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
187187
│ └── #23
188188
├── groups: [ #41 ]
189189
└── PhysicalFilter
@@ -218,10 +218,10 @@ PhysicalSort
218218
│ │ │ └── "Asia"
219219
│ │ └── Geq
220220
│ │ ├── #12
221-
│ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
221+
│ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
222222
│ └── Lt
223223
│ ├── #12
224-
│ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
224+
│ └── Cast { cast_to: Date32, expr: "2024-01-01" }
225225
└── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
226226
├── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
227227
│ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
@@ -260,14 +260,14 @@ LogicalProjection { exprs: [ #0 ] }
260260
│ │ ├── And
261261
│ │ │ ├── Geq
262262
│ │ │ │ ├── #10
263-
│ │ │ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
263+
│ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
264264
│ │ │ └── Lt
265265
│ │ │ ├── #10
266-
│ │ │ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
267-
│ │ └── Between { expr: Cast { cast_to: Decimal128(0), expr: #6 }, lower: Cast { cast_to: Decimal128(0), expr: 0.05 }, upper: Cast { cast_to: Decimal128(0), expr: 0.07 } }
266+
│ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" }
267+
│ │ └── Between { expr: Cast { cast_to: Decimal128(30, 15), expr: #6 }, lower: Cast { cast_to: Decimal128(30, 15), expr: 0.05 }, upper: Cast { cast_to: Decimal128(30, 15), expr: 0.07 } }
268268
│ └── Lt
269-
│ ├── Cast { cast_to: Decimal128(0), expr: #4 }
270-
│ └── Cast { cast_to: Decimal128(0), expr: 24 }
269+
│ ├── Cast { cast_to: Decimal128(22, 2), expr: #4 }
270+
│ └── Cast { cast_to: Decimal128(22, 2), expr: 24 }
271271
└── LogicalScan { table: lineitem }
272272
PhysicalProjection { exprs: [ #0 ] }
273273
└── PhysicalAgg
@@ -282,14 +282,14 @@ PhysicalProjection { exprs: [ #0 ] }
282282
│ │ ├── And
283283
│ │ │ ├── Geq
284284
│ │ │ │ ├── #10
285-
│ │ │ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
285+
│ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
286286
│ │ │ └── Lt
287287
│ │ │ ├── #10
288-
│ │ │ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
289-
│ │ └── Between { expr: Cast { cast_to: Decimal128(0), expr: #6 }, lower: Cast { cast_to: Decimal128(0), expr: 0.05 }, upper: Cast { cast_to: Decimal128(0), expr: 0.07 } }
288+
│ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" }
289+
│ │ └── Between { expr: Cast { cast_to: Decimal128(30, 15), expr: #6 }, lower: Cast { cast_to: Decimal128(30, 15), expr: 0.05 }, upper: Cast { cast_to: Decimal128(30, 15), expr: 0.07 } }
290290
│ └── Lt
291-
│ ├── Cast { cast_to: Decimal128(0), expr: #4 }
292-
│ └── Cast { cast_to: Decimal128(0), expr: 24 }
291+
│ ├── Cast { cast_to: Decimal128(22, 2), expr: #4 }
292+
│ └── Cast { cast_to: Decimal128(22, 2), expr: 24 }
293293
└── PhysicalScan { table: lineitem }
294294
*/
295295

@@ -541,7 +541,7 @@ LogicalSort
541541
│ │ │ ├── #2
542542
│ │ │ └── "IRAQ"
543543
│ │ ├── #1
544-
│ │ └── Cast { cast_to: Decimal128(0), expr: 0 }
544+
│ │ └── Cast { cast_to: Decimal128(38, 4), expr: 0 }
545545
│ └── Agg(Sum)
546546
│ └── [ #1 ]
547547
├── groups: [ #0 ]
@@ -552,7 +552,7 @@ LogicalSort
552552
│ ├── Mul
553553
│ │ ├── #21
554554
│ │ └── Sub
555-
│ │ ├── Cast { cast_to: Decimal128(0), expr: 1 }
555+
│ │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
556556
│ │ └── #22
557557
│ └── #54
558558
└── LogicalFilter
@@ -589,7 +589,7 @@ LogicalSort
589589
│ │ │ └── Eq
590590
│ │ │ ├── #12
591591
│ │ │ └── #53
592-
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32(0), expr: "1995-01-01" }, upper: Cast { cast_to: Date32(0), expr: "1996-12-31" } }
592+
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } }
593593
│ └── Eq
594594
│ ├── #4
595595
│ └── "ECONOMY ANODIZED STEEL"
@@ -626,7 +626,7 @@ PhysicalSort
626626
│ │ │ ├── #2
627627
│ │ │ └── "IRAQ"
628628
│ │ ├── #1
629-
│ │ └── Cast { cast_to: Decimal128(0), expr: 0 }
629+
│ │ └── Cast { cast_to: Decimal128(38, 4), expr: 0 }
630630
│ └── Agg(Sum)
631631
│ └── [ #1 ]
632632
├── groups: [ #0 ]
@@ -637,7 +637,7 @@ PhysicalSort
637637
│ ├── Mul
638638
│ │ ├── #21
639639
│ │ └── Sub
640-
│ │ ├── Cast { cast_to: Decimal128(0), expr: 1 }
640+
│ │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
641641
│ │ └── #22
642642
│ └── #54
643643
└── PhysicalFilter
@@ -674,7 +674,7 @@ PhysicalSort
674674
│ │ │ └── Eq
675675
│ │ │ ├── #12
676676
│ │ │ └── #53
677-
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32(0), expr: "1995-01-01" }, upper: Cast { cast_to: Date32(0), expr: "1996-12-31" } }
677+
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } }
678678
│ └── Eq
679679
│ ├── #4
680680
│ └── "ECONOMY ANODIZED STEEL"

0 commit comments

Comments
 (0)