From 02e2267c796129245f6f72547930257b3b9c02d7 Mon Sep 17 00:00:00 2001 From: AveryQi115 Date: Mon, 12 Feb 2024 14:01:54 -0500 Subject: [PATCH 1/3] align schema Signed-off-by: AveryQi115 --- datafusion-optd-cli/tests/cli_integration.rs | 7 ++-- optd-core/src/cascades/optimizer.rs | 4 +-- .../src/cascades/tasks/optimize_inputs.rs | 7 +--- optd-core/src/heuristics/optimizer.rs | 5 +-- optd-datafusion-bridge/src/from_optd.rs | 29 +++++++++------- optd-datafusion-bridge/src/lib.rs | 18 +++++++--- optd-datafusion-repr/src/properties/schema.rs | 33 +++++++++++++++---- optd-datafusion-repr/src/rules/joins.rs | 14 ++++---- 8 files changed, 72 insertions(+), 45 deletions(-) diff --git a/datafusion-optd-cli/tests/cli_integration.rs b/datafusion-optd-cli/tests/cli_integration.rs index 8a9f7cee..663fa150 100644 --- a/datafusion-optd-cli/tests/cli_integration.rs +++ b/datafusion-optd-cli/tests/cli_integration.rs @@ -57,5 +57,8 @@ fn cli_test_tpch() { cmd.current_dir(".."); // all paths in `test.sql` assume we're in the base dir of the repo cmd.args(["--enable-logical", "--file", "tpch/test.sql"]); let status = cmd.status().unwrap(); - assert!(status.success(), "should not have crashed when running tpch"); -} \ No newline at end of file + assert!( + status.success(), + "should not have crashed when running tpch" + ); +} diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs index c3871cdd..ad1e4c4b 100644 --- a/optd-core/src/cascades/optimizer.rs +++ b/optd-core/src/cascades/optimizer.rs @@ -214,9 +214,7 @@ impl CascadesOptimizer { group_id: GroupId, mut on_produce: impl FnMut(RelNodeRef, GroupId) -> RelNodeRef, ) -> Result> { - self - .memo - .get_best_group_binding(group_id, &mut on_produce) + self.memo.get_best_group_binding(group_id, &mut on_produce) } fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> { diff --git a/optd-core/src/cascades/tasks/optimize_inputs.rs b/optd-core/src/cascades/tasks/optimize_inputs.rs index 6cb6b85a..f6b35ca7 100644 --- a/optd-core/src/cascades/tasks/optimize_inputs.rs +++ b/optd-core/src/cascades/tasks/optimize_inputs.rs @@ -243,12 +243,7 @@ impl Task for OptimizeInputsTask { } else { self.update_winner( &cost.sum( - &cost.compute_cost( - &expr.typ, - &expr.data, - &input_cost, - Some(context), - ), + &cost.compute_cost(&expr.typ, &expr.data, &input_cost, Some(context)), &input_cost, ), optimizer, diff --git a/optd-core/src/heuristics/optimizer.rs b/optd-core/src/heuristics/optimizer.rs index 4a081a0d..2c1b1bf8 100644 --- a/optd-core/src/heuristics/optimizer.rs +++ b/optd-core/src/heuristics/optimizer.rs @@ -49,10 +49,7 @@ fn match_node( assert!(res.is_none(), "dup pick"); } RuleMatcher::PickMany { pick_to } => { - let res = pick.insert( - *pick_to, - RelNode::new_list(node.children[idx..].to_vec()), - ); + let res = pick.insert(*pick_to, RelNode::new_list(node.children[idx..].to_vec())); assert!(res.is_none(), "dup pick"); should_end = true; } diff --git a/optd-datafusion-bridge/src/from_optd.rs b/optd-datafusion-bridge/src/from_optd.rs index a1968200..191199f6 100644 --- a/optd-datafusion-bridge/src/from_optd.rs +++ b/optd-datafusion-bridge/src/from_optd.rs @@ -36,7 +36,7 @@ use crate::{physical_collector::CollectorExec, OptdPlanContext}; // TODO: current DataType and ConstantType are not 1 to 1 mapping // optd schema stores constantType from data type in catalog.get // for decimal128, the precision is lost -fn from_optd_schema(optd_schema: &OptdSchema) -> Schema { +fn from_optd_schema(optd_schema: OptdSchema) -> Schema { let match_type = |typ: &ConstantType| match typ { ConstantType::Any => unimplemented!(), ConstantType::Bool => DataType::Boolean, @@ -52,12 +52,15 @@ fn from_optd_schema(optd_schema: &OptdSchema) -> Schema { ConstantType::Decimal => DataType::Float64, ConstantType::Utf8String => DataType::Utf8, }; - let fields: Vec<_> = optd_schema - .0 - .iter() - .enumerate() - .map(|(i, typ)| Field::new(format!("c{}", i), match_type(typ), false)) - .collect(); + let mut fields = vec![]; + fields.reserve(optd_schema.len()); + for field in optd_schema.fields { + fields.push(Field::new( + field.name, + match_type(&field.typ), + field.nullable, + )); + } Schema::new(fields) } @@ -351,7 +354,8 @@ impl OptdPlanContext<'_> { Schema::new_with_metadata(fields, HashMap::new()) }; - let physical_expr = Self::conv_from_optd_expr(node.cond(), &Arc::new(filter_schema.clone()))?; + let physical_expr = + Self::conv_from_optd_expr(node.cond(), &Arc::new(filter_schema.clone()))?; if let JoinType::Cross = node.join_type() { return Ok(Arc::new(CrossJoinExec::new(left_exec, right_exec)) @@ -436,7 +440,7 @@ impl OptdPlanContext<'_> { #[async_recursion] async fn conv_from_optd_plan_node(&mut self, node: PlanNode) -> Result> { - let mut schema = OptdSchema(vec![]); + let mut schema = OptdSchema { fields: vec![] }; if node.typ() == OptRelNodeTyp::PhysicalEmptyRelation { schema = node.schema(self.optimizer.unwrap().optd_optimizer()); } @@ -484,7 +488,7 @@ impl OptdPlanContext<'_> { } OptRelNodeTyp::PhysicalEmptyRelation => { let physical_node = PhysicalEmptyRelation::from_rel_node(rel_node).unwrap(); - let datafusion_schema: Schema = from_optd_schema(&schema); + let datafusion_schema: Schema = from_optd_schema(schema); Ok(Arc::new(datafusion::physical_plan::empty::EmptyExec::new( physical_node.produce_one_row(), Arc::new(datafusion_schema), @@ -495,7 +499,10 @@ impl OptdPlanContext<'_> { result.with_context(|| format!("when processing {}", rel_node_dbg)) } - pub async fn conv_from_optd(&mut self, root_rel: OptRelNodeRef) -> Result> { + pub async fn conv_from_optd( + &mut self, + root_rel: OptRelNodeRef, + ) -> Result> { self.conv_from_optd_plan_node(PlanNode::from_rel_node(root_rel).unwrap()) .await } diff --git a/optd-datafusion-bridge/src/lib.rs b/optd-datafusion-bridge/src/lib.rs index cac9074d..a4804ac6 100644 --- a/optd-datafusion-bridge/src/lib.rs +++ b/optd-datafusion-bridge/src/lib.rs @@ -61,9 +61,11 @@ impl Catalog for DatafusionCatalog { let catalog = self.catalog.catalog("datafusion").unwrap(); let schema = catalog.schema("public").unwrap(); let table = futures_lite::future::block_on(schema.table(name.as_ref())).unwrap(); - let fields = table.schema(); - let mut optd_schema = vec![]; - for field in fields.fields() { + let schema = table.schema(); + let fields = schema.fields(); + let mut optd_fields = vec![]; + optd_fields.reserve(fields.len()); + for field in fields { let dt = match field.data_type() { DataType::Date32 => ConstantType::Date, DataType::Int32 => ConstantType::Int32, @@ -73,9 +75,15 @@ impl Catalog for DatafusionCatalog { DataType::Decimal128(_, _) => ConstantType::Decimal, dt => unimplemented!("{:?}", dt), }; - optd_schema.push(dt); + optd_fields.push(optd_datafusion_repr::properties::schema::Field { + name: field.name().to_string(), + typ: dt, + nullable: field.is_nullable(), + }); + } + optd_datafusion_repr::properties::schema::Schema { + fields: optd_fields, } - optd_datafusion_repr::properties::schema::Schema(optd_schema) } } diff --git a/optd-datafusion-repr/src/properties/schema.rs b/optd-datafusion-repr/src/properties/schema.rs index 09ff2eee..cf1b376d 100644 --- a/optd-datafusion-repr/src/properties/schema.rs +++ b/optd-datafusion-repr/src/properties/schema.rs @@ -3,12 +3,19 @@ use optd_core::property::PropertyBuilder; use crate::plan_nodes::{ConstantType, OptRelNodeTyp}; #[derive(Clone, Debug)] -pub struct Schema(pub Vec); +pub struct Field { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} +#[derive(Clone, Debug)] +pub struct Schema { + pub fields: Vec, +} -// TODO: add names, nullable to schema impl Schema { pub fn len(&self) -> usize { - self.0.len() + self.fields.len() } pub fn is_empty(&self) -> bool { @@ -48,11 +55,25 @@ impl PropertyBuilder for SchemaPropertyBuilder { OptRelNodeTyp::Filter => children[0].clone(), OptRelNodeTyp::Join(_) => { let mut schema = children[0].clone(); - schema.0.extend(children[1].clone().0); + let schema2 = children[1].clone(); + schema.fields.extend(schema2.fields); + schema + } + OptRelNodeTyp::List => { + // TODO: calculate real is_nullable for aggregations + let schema = Schema { + fields: vec![ + Field { + name: "unnamed".to_string(), + typ: ConstantType::Any, + nullable: true + }; + children.len() + ], + }; schema } - OptRelNodeTyp::List => Schema(vec![ConstantType::Any; children.len()]), - _ => Schema(vec![]), + _ => Schema { fields: vec![] }, } } diff --git a/optd-datafusion-repr/src/rules/joins.rs b/optd-datafusion-repr/src/rules/joins.rs index a0314140..a6453e3c 100644 --- a/optd-datafusion-repr/src/rules/joins.rs +++ b/optd-datafusion-repr/src/rules/joins.rs @@ -68,7 +68,7 @@ fn apply_join_commute( cond, JoinType::Inner, ); - let mut proj_expr = Vec::with_capacity(left_schema.0.len() + right_schema.0.len()); + let mut proj_expr = Vec::with_capacity(left_schema.len() + right_schema.len()); for i in 0..left_schema.len() { proj_expr.push(ColumnRefExpr::new(right_schema.len() + i).into_expr()); } @@ -218,13 +218,11 @@ fn apply_hash_join( let Some(mut right_expr) = ColumnRefExpr::from_rel_node(right_expr.into_rel_node()) else { return vec![]; }; - let can_convert = if left_expr.index() < left_schema.0.len() - && right_expr.index() >= left_schema.0.len() + let can_convert = if left_expr.index() < left_schema.len() + && right_expr.index() >= left_schema.len() { true - } else if right_expr.index() < left_schema.0.len() - && left_expr.index() >= left_schema.0.len() - { + } else if right_expr.index() < left_schema.len() && left_expr.index() >= left_schema.len() { (left_expr, right_expr) = (right_expr, left_expr); true } else { @@ -232,7 +230,7 @@ fn apply_hash_join( }; if can_convert { - let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.0.len()); + let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.len()); let node = PhysicalHashJoin::new( PlanNode::from_group(left.into()), PlanNode::from_group(right.into()), @@ -342,7 +340,7 @@ fn apply_projection_pull_up_join( .into_rel_node(), ); } - + Expr::from_rel_node( RelNode { typ: expr.typ.clone(), From 98294a04699a2827260e31bae4ff5384e2771c26 Mon Sep 17 00:00:00 2001 From: AveryQi115 Date: Mon, 12 Feb 2024 14:09:10 -0500 Subject: [PATCH 2/3] fix fmt warning Signed-off-by: AveryQi115 --- optd-datafusion-repr/src/properties/schema.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optd-datafusion-repr/src/properties/schema.rs b/optd-datafusion-repr/src/properties/schema.rs index cf1b376d..6ef5872c 100644 --- a/optd-datafusion-repr/src/properties/schema.rs +++ b/optd-datafusion-repr/src/properties/schema.rs @@ -61,7 +61,7 @@ impl PropertyBuilder for SchemaPropertyBuilder { } OptRelNodeTyp::List => { // TODO: calculate real is_nullable for aggregations - let schema = Schema { + Schema { fields: vec![ Field { name: "unnamed".to_string(), @@ -70,8 +70,7 @@ impl PropertyBuilder for SchemaPropertyBuilder { }; children.len() ], - }; - schema + } } _ => Schema { fields: vec![] }, } From ee20069787b9050581d88e9734d71fdb3d81d682 Mon Sep 17 00:00:00 2001 From: AveryQi115 Date: Mon, 12 Feb 2024 14:16:46 -0500 Subject: [PATCH 3/3] update local rust version and fix fmt warning Signed-off-by: AveryQi115 --- optd-datafusion-bridge/src/from_optd.rs | 3 +-- optd-datafusion-bridge/src/lib.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/optd-datafusion-bridge/src/from_optd.rs b/optd-datafusion-bridge/src/from_optd.rs index 191199f6..df521716 100644 --- a/optd-datafusion-bridge/src/from_optd.rs +++ b/optd-datafusion-bridge/src/from_optd.rs @@ -52,8 +52,7 @@ fn from_optd_schema(optd_schema: OptdSchema) -> Schema { ConstantType::Decimal => DataType::Float64, ConstantType::Utf8String => DataType::Utf8, }; - let mut fields = vec![]; - fields.reserve(optd_schema.len()); + let mut fields = Vec::with_capacity(optd_schema.len()); for field in optd_schema.fields { fields.push(Field::new( field.name, diff --git a/optd-datafusion-bridge/src/lib.rs b/optd-datafusion-bridge/src/lib.rs index a4804ac6..601e79b8 100644 --- a/optd-datafusion-bridge/src/lib.rs +++ b/optd-datafusion-bridge/src/lib.rs @@ -63,8 +63,7 @@ impl Catalog for DatafusionCatalog { let table = futures_lite::future::block_on(schema.table(name.as_ref())).unwrap(); let schema = table.schema(); let fields = schema.fields(); - let mut optd_fields = vec![]; - optd_fields.reserve(fields.len()); + let mut optd_fields = Vec::with_capacity(fields.len()); for field in fields { let dt = match field.data_type() { DataType::Date32 => ConstantType::Date,