Skip to content

Commit dc7e44d

Browse files
authored
fix: Unconditionally wrap UNION BY NAME input nodes w/ Projection (#15242)
* fix: Remove incorrect predicate to skip input wrapping when rewriting union inputs
* chore: Add/update tests
* fix: SQL integration tests
* test: Add union all by name SLT tests
* test: Add problematic union all by name SLT test
* chore: styling nits
* fix: Correct handling of nullability when field is not present in all inputs
* chore: Update fixme comment
* fix: handle ordering by order of inputs
1 parent dc073ff commit dc7e44d

File tree

3 files changed

+229
-83
lines changed

3 files changed

+229
-83
lines changed

datafusion/expr/src/logical_plan/plan.rs

+48-41
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Logical plan types
1919
2020
use std::cmp::Ordering;
21-
use std::collections::{BTreeMap, HashMap, HashSet};
21+
use std::collections::{HashMap, HashSet};
2222
use std::fmt::{self, Debug, Display, Formatter};
2323
use std::hash::{Hash, Hasher};
2424
use std::str::FromStr;
@@ -2681,24 +2681,16 @@ impl Union {
26812681
Ok(Union { inputs, schema })
26822682
}
26832683

2684-
/// When constructing a `UNION BY NAME`, we may need to wrap inputs
2684+
/// When constructing a `UNION BY NAME`, we need to wrap inputs
26852685
/// in an additional `Projection` to account for absence of columns
2686-
/// in input schemas.
2686+
/// in input schemas or differing projection orders.
26872687
fn rewrite_inputs_from_schema(
2688-
schema: &DFSchema,
2688+
schema: &Arc<DFSchema>,
26892689
inputs: Vec<Arc<LogicalPlan>>,
26902690
) -> Result<Vec<Arc<LogicalPlan>>> {
26912691
let schema_width = schema.iter().count();
26922692
let mut wrapped_inputs = Vec::with_capacity(inputs.len());
26932693
for input in inputs {
2694-
// If the input plan's schema contains the same number of fields
2695-
// as the derived schema, then it does not to be wrapped in an
2696-
// additional `Projection`.
2697-
if input.schema().iter().count() == schema_width {
2698-
wrapped_inputs.push(input);
2699-
continue;
2700-
}
2701-
27022694
// Any columns that exist within the derived schema but do not exist
27032695
// within an input's schema should be replaced with `NULL` aliased
27042696
// to the appropriate column in the derived schema.
@@ -2713,9 +2705,9 @@ impl Union {
27132705
expr.push(Expr::Literal(ScalarValue::Null).alias(column.name()));
27142706
}
27152707
}
2716-
wrapped_inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new(
2717-
expr, input,
2718-
)?)));
2708+
wrapped_inputs.push(Arc::new(LogicalPlan::Projection(
2709+
Projection::try_new_with_schema(expr, input, Arc::clone(schema))?,
2710+
)));
27192711
}
27202712

27212713
Ok(wrapped_inputs)
@@ -2749,45 +2741,60 @@ impl Union {
27492741
inputs: &[Arc<LogicalPlan>],
27502742
loose_types: bool,
27512743
) -> Result<DFSchemaRef> {
2752-
type FieldData<'a> = (&'a DataType, bool, Vec<&'a HashMap<String, String>>);
2753-
// Prefer `BTreeMap` as it produces items in order by key when iterated over
2754-
let mut cols: BTreeMap<&str, FieldData> = BTreeMap::new();
2744+
type FieldData<'a> =
2745+
(&'a DataType, bool, Vec<&'a HashMap<String, String>>, usize);
2746+
let mut cols: Vec<(&str, FieldData)> = Vec::new();
27552747
for input in inputs.iter() {
27562748
for field in input.schema().fields() {
2757-
match cols.entry(field.name()) {
2758-
std::collections::btree_map::Entry::Occupied(mut occupied) => {
2759-
let (data_type, is_nullable, metadata) = occupied.get_mut();
2760-
if !loose_types && *data_type != field.data_type() {
2761-
return plan_err!(
2762-
"Found different types for field {}",
2763-
field.name()
2764-
);
2765-
}
2766-
2767-
metadata.push(field.metadata());
2768-
// If the field is nullable in any one of the inputs,
2769-
// then the field in the final schema is also nullable.
2770-
*is_nullable |= field.is_nullable();
2749+
if let Some((_, (data_type, is_nullable, metadata, occurrences))) =
2750+
cols.iter_mut().find(|(name, _)| name == field.name())
2751+
{
2752+
if !loose_types && *data_type != field.data_type() {
2753+
return plan_err!(
2754+
"Found different types for field {}",
2755+
field.name()
2756+
);
27712757
}
2772-
std::collections::btree_map::Entry::Vacant(vacant) => {
2773-
vacant.insert((
2758+
2759+
metadata.push(field.metadata());
2760+
// If the field is nullable in any one of the inputs,
2761+
// then the field in the final schema is also nullable.
2762+
*is_nullable |= field.is_nullable();
2763+
*occurrences += 1;
2764+
} else {
2765+
cols.push((
2766+
field.name(),
2767+
(
27742768
field.data_type(),
27752769
field.is_nullable(),
27762770
vec![field.metadata()],
2777-
));
2778-
}
2771+
1,
2772+
),
2773+
));
27792774
}
27802775
}
27812776
}
27822777

27832778
let union_fields = cols
27842779
.into_iter()
2785-
.map(|(name, (data_type, is_nullable, unmerged_metadata))| {
2786-
let mut field = Field::new(name, data_type.clone(), is_nullable);
2787-
field.set_metadata(intersect_maps(unmerged_metadata));
2780+
.map(
2781+
|(name, (data_type, is_nullable, unmerged_metadata, occurrences))| {
2782+
// If the final number of occurrences of the field is less
2783+
// than the number of inputs (i.e. the field is missing from
2784+
// one or more inputs), then it must be treated as nullable.
2785+
let final_is_nullable = if occurrences == inputs.len() {
2786+
is_nullable
2787+
} else {
2788+
true
2789+
};
27882790

2789-
(None, Arc::new(field))
2790-
})
2791+
let mut field =
2792+
Field::new(name, data_type.clone(), final_is_nullable);
2793+
field.set_metadata(intersect_maps(unmerged_metadata));
2794+
2795+
(None, Arc::new(field))
2796+
},
2797+
)
27912798
.collect::<Vec<(Option<TableReference>, _)>>();
27922799

27932800
let union_schema_metadata =

datafusion/sql/tests/sql_integration.rs

+16-11
Original file line numberDiff line numberDiff line change
@@ -1898,11 +1898,12 @@ fn union_by_name_different_columns() {
18981898
let expected = "\
18991899
Distinct:\
19001900
\n Union\
1901-
\n Projection: NULL AS Int64(1), order_id\
1901+
\n Projection: order_id, NULL AS Int64(1)\
19021902
\n Projection: orders.order_id\
19031903
\n TableScan: orders\
1904-
\n Projection: orders.order_id, Int64(1)\
1905-
\n TableScan: orders";
1904+
\n Projection: order_id, Int64(1)\
1905+
\n Projection: orders.order_id, Int64(1)\
1906+
\n TableScan: orders";
19061907
quick_test(sql, expected);
19071908
}
19081909

@@ -1936,22 +1937,26 @@ fn union_all_by_name_different_columns() {
19361937
"SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM orders";
19371938
let expected = "\
19381939
Union\
1939-
\n Projection: NULL AS Int64(1), order_id\
1940+
\n Projection: order_id, NULL AS Int64(1)\
19401941
\n Projection: orders.order_id\
19411942
\n TableScan: orders\
1942-
\n Projection: orders.order_id, Int64(1)\
1943-
\n TableScan: orders";
1943+
\n Projection: order_id, Int64(1)\
1944+
\n Projection: orders.order_id, Int64(1)\
1945+
\n TableScan: orders";
19441946
quick_test(sql, expected);
19451947
}
19461948

19471949
#[test]
19481950
fn union_all_by_name_same_column_names() {
19491951
let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id FROM orders";
1950-
let expected = "Union\
1951-
\n Projection: orders.order_id\
1952-
\n TableScan: orders\
1953-
\n Projection: orders.order_id\
1954-
\n TableScan: orders";
1952+
let expected = "\
1953+
Union\
1954+
\n Projection: order_id\
1955+
\n Projection: orders.order_id\
1956+
\n TableScan: orders\
1957+
\n Projection: order_id\
1958+
\n Projection: orders.order_id\
1959+
\n TableScan: orders";
19551960
quick_test(sql, expected);
19561961
}
19571962

0 commit comments

Comments
 (0)