
Commit

Bugfix: Add functional dependency check and aggregate try_new schema (apache#8584)

* Add functional dependency check and aggregate try_new schema

* Update comments, make implementation idiomatic

* Use constraint during stream table initialization
mustafasrepo authored Dec 20, 2023
1 parent 1bcaac4 commit 6f5230f
Showing 5 changed files with 125 additions and 11 deletions.
16 changes: 16 additions & 0 deletions datafusion/common/src/dfschema.rs
@@ -347,6 +347,22 @@ impl DFSchema {
.collect()
}

/// Find the indices of all fields with the given qualifier
pub fn fields_indices_with_qualified(
&self,
qualifier: &TableReference,
) -> Vec<usize> {
self.fields
.iter()
.enumerate()
.filter_map(|(idx, field)| {
field
.qualifier()
.and_then(|q| q.eq(qualifier).then_some(idx))
})
.collect()
}

/// Find all fields that match the given name
pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
self.fields
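For context, here is a minimal usage sketch of the new `fields_indices_with_qualified` helper added above. It is illustrative only; the `DFField::new`, `DFSchema::new_with_metadata`, and `TableReference::bare` constructor signatures are assumed to match the DataFusion APIs current at the time of this commit.

use std::collections::HashMap;

use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, TableReference};

fn main() -> Result<()> {
    // A two-table schema: l.a, r.a, r.b.
    let schema = DFSchema::new_with_metadata(
        vec![
            DFField::new(Some("l"), "a", DataType::Int32, true),
            DFField::new(Some("r"), "a", DataType::Int32, true),
            DFField::new(Some("r"), "b", DataType::Int32, true),
        ],
        HashMap::new(),
    )?;
    // The helper returns the positions of every field qualified with `r`.
    let indices = schema.fields_indices_with_qualified(&TableReference::bare("r"));
    assert_eq!(indices, vec![1, 2]);
    Ok(())
}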
3 changes: 2 additions & 1 deletion datafusion/core/src/datasource/stream.rs
@@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory {
.with_encoding(encoding)
.with_order(cmd.order_exprs.clone())
.with_header(cmd.has_header)
.with_batch_size(state.config().batch_size());
.with_batch_size(state.config().batch_size())
.with_constraints(cmd.constraints.clone());

Ok(Arc::new(StreamTable(Arc::new(config))))
}
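The `with_constraints` call above forwards the constraints declared on the `CREATE EXTERNAL TABLE` statement into the stream table configuration. As a rough illustration of what such a constraint looks like (assuming the `Constraints::new_unverified` and `Constraint::PrimaryKey` constructors from `datafusion_common`), a `PRIMARY KEY (c0)` declaration corresponds to:

use datafusion_common::{Constraint, Constraints};

fn main() {
    // A primary key over column index 0; the factory above now passes such
    // constraints through to the `StreamTable` config via `with_constraints`.
    let _constraints = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
}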
13 changes: 7 additions & 6 deletions datafusion/expr/src/utils.rs
@@ -32,6 +32,7 @@ use crate::{

use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue, TableReference,
@@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard(
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let qualifier = TableReference::from(qualifier);
let qualified_fields: Vec<DFField> = schema
.fields_with_qualified(&qualifier)
.into_iter()
.cloned()
.collect();
let qualified_indices = schema.fields_indices_with_qualified(&qualifier);
let projected_func_dependencies = schema
.functional_dependencies()
.project_functional_dependencies(&qualified_indices, qualified_indices.len());
let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?;
if qualified_fields.is_empty() {
return plan_err!("Invalid qualifier {qualifier}");
}
let qualified_schema =
DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?
// We can use the functional dependencies as is, since it only stores indices:
.with_functional_dependencies(schema.functional_dependencies().clone())?;
.with_functional_dependencies(projected_func_dependencies)?;
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
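The change above projects the schema's functional dependencies onto the indices of the qualified fields rather than copying them verbatim: since dependencies are stored as column indices, they are only meaningful relative to the projected schema. Below is a self-contained conceptual sketch using a hypothetical helper (not the actual `project_functional_dependencies` API):

/// Hypothetical, simplified stand-in for dependency projection: a dependency
/// survives only if both its determinant and its dependent column are kept by
/// the projection, and its indices are remapped to their new positions.
fn project_dependency(dep: (usize, usize), proj: &[usize]) -> Option<(usize, usize)> {
    let source = proj.iter().position(|&i| i == dep.0)?;
    let target = proj.iter().position(|&i| i == dep.1)?;
    Some((source, target))
}

fn main() {
    // Full schema [l.a, l.b, r.a, r.b] with the dependency r.a -> r.b (2 -> 3).
    // Expanding `r.*` keeps indices [2, 3], so the dependency is remapped to 0 -> 1.
    assert_eq!(project_dependency((2, 3), &[2, 3]), Some((0, 1)));
    // A dependency among the l-side columns does not survive the projection.
    assert_eq!(project_dependency((0, 1), &[2, 3]), None);
}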
92 changes: 88 additions & 4 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -43,7 +43,7 @@ use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_physical_expr::{
aggregate::is_order_sensitive,
equivalence::collapse_lex_req,
equivalence::{collapse_lex_req, ProjectionMapping},
expressions::{Column, Max, Min, UnKnownColumn},
physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties,
LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
@@ -59,7 +59,6 @@ mod topk;
mod topk_stream;

pub use datafusion_expr::AggregateFunction;
use datafusion_physical_expr::equivalence::ProjectionMapping;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;

/// Hash aggregate modes
@@ -464,7 +463,7 @@ impl AggregateExec {
pub fn try_new(
mode: AggregateMode,
group_by: PhysicalGroupBy,
mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
@@ -482,6 +481,37 @@ impl AggregateExec {
group_by.expr.len(),
));
let original_schema = Arc::new(original_schema);
AggregateExec::try_new_with_schema(
mode,
group_by,
aggr_expr,
filter_expr,
input,
input_schema,
schema,
original_schema,
)
}

/// Create a new hash aggregate execution plan with the given schema.
/// This constructor isn't part of the public API; it is used internally
/// by DataFusion to enforce schema consistency when re-creating
/// `AggregateExec`s inside optimization rules. The schema field names of an
/// `AggregateExec` depend on the names of its aggregate expressions. Since
/// a rule may rewrite aggregate expressions (e.g. reverse them) during
/// initialization, field names could change inadvertently if the schema
/// were re-created in such cases.
#[allow(clippy::too_many_arguments)]
fn try_new_with_schema(
mode: AggregateMode,
group_by: PhysicalGroupBy,
mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
schema: SchemaRef,
original_schema: SchemaRef,
) -> Result<Self> {
// Reset ordering requirement to `None` if aggregator is not order-sensitive
let mut order_by_expr = aggr_expr
.iter()
@@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut me = AggregateExec::try_new(
let mut me = AggregateExec::try_new_with_schema(
self.mode,
self.group_by.clone(),
self.aggr_expr.clone(),
self.filter_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
self.schema.clone(),
self.original_schema.clone(),
)?;
me.limit = self.limit;
Ok(Arc::new(me))
@@ -2162,4 +2194,56 @@ mod tests {
assert_eq!(res, common_requirement);
Ok(())
}

#[test]
fn test_agg_exec_same_schema() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, true),
Field::new("b", DataType::Float32, true),
]));

let col_a = col("a", &schema)?;
let col_b = col("b", &schema)?;
let option_desc = SortOptions {
descending: true,
nulls_first: true,
};
let sort_expr = vec![PhysicalSortExpr {
expr: col_b.clone(),
options: option_desc,
}];
let sort_expr_reverse = reverse_order_bys(&sort_expr);
let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![
Arc::new(FirstValue::new(
col_b.clone(),
"FIRST_VALUE(b)".to_string(),
DataType::Float64,
sort_expr_reverse.clone(),
vec![DataType::Float64],
)),
Arc::new(LastValue::new(
col_b.clone(),
"LAST_VALUE(b)".to_string(),
DataType::Float64,
sort_expr.clone(),
vec![DataType::Float64],
)),
];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups,
aggregates.clone(),
vec![None, None],
blocking_exec.clone(),
schema,
)?);
let new_agg = aggregate_exec
.clone()
.with_new_children(vec![blocking_exec])?;
assert_eq!(new_agg.schema(), aggregate_exec.schema());
Ok(())
}
}
12 changes: 12 additions & 0 deletions datafusion/sqllogictest/test_files/groupby.slt
@@ -4280,3 +4280,15 @@ LIMIT 5
2 0 0
3 0 0
4 0 1


query ITIPTR rowsort
SELECT r.*
FROM sales_global_with_pk as l, sales_global_with_pk as r
LIMIT 5
----
0 GRC 0 2022-01-01T06:00:00 EUR 30
1 FRA 1 2022-01-01T08:00:00 EUR 50
1 FRA 3 2022-01-02T12:00:00 EUR 200
1 TUR 2 2022-01-01T11:30:00 TRY 75
1 TUR 4 2022-01-03T10:00:00 TRY 100
