Skip to content

Commit

Permalink
feat: support Map literals in Substrait consumer and producer (apache…
Browse files Browse the repository at this point in the history
…#11547)

* implement Map literals/nulls conversions in Substrait

* fix name handling for lists/maps containing structs

* add hashing for map scalars

* add a test for creating a map in VALUES

* fix clippy

* better test

* use MapBuilder in test

* fix hash test

* remove unnecessary type variation checks from maps
  • Loading branch information
Blizzara authored Jul 23, 2024
1 parent 67c6ee2 commit f80dde0
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 27 deletions.
102 changes: 100 additions & 2 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano;

use crate::cast::{
as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
as_large_list_array, as_list_array, as_primitive_array, as_string_array,
as_struct_array,
as_large_list_array, as_list_array, as_map_array, as_primitive_array,
as_string_array, as_struct_array,
};
use crate::error::{Result, _internal_err};

Expand Down Expand Up @@ -236,6 +236,40 @@ fn hash_struct_array(
Ok(())
}

/// Hashes each row of a [`MapArray`] into `hashes_buffer`.
///
/// The hash for a row is produced by combining the hashes of all of that
/// row's key/value entries with whatever hash is already accumulated in
/// `hashes_buffer[row]`. Rows that are null in the map's validity mask are
/// left untouched.
fn hash_map_array(
    array: &MapArray,
    random_state: &RandomState,
    hashes_buffer: &mut [u64],
) -> Result<()> {
    // Hash every key/value entry across all rows in a single pass over the
    // flattened entries struct.
    let mut entry_hashes = vec![0u64; array.entries().len()];
    create_hashes(array.entries().columns(), random_state, &mut entry_hashes)?;

    let nulls = array.nulls();
    let offsets = array.offsets();

    // Fold each valid row's entry hashes into that row's accumulated hash.
    // `windows(2)` yields each row's (start, end) offset pair.
    for (row, bounds) in offsets.windows(2).enumerate() {
        // A missing null buffer means every row is valid.
        if nulls.map_or(true, |n| n.is_valid(row)) {
            let hash = &mut hashes_buffer[row];
            for entry_hash in &entry_hashes[bounds[0].as_usize()..bounds[1].as_usize()] {
                *hash = combine_hashes(*hash, *entry_hash);
            }
        }
    }

    Ok(())
}

fn hash_list_array<OffsetSize>(
array: &GenericListArray<OffsetSize>,
random_state: &RandomState,
Expand Down Expand Up @@ -400,6 +434,10 @@ pub fn create_hashes<'a>(
let array = as_large_list_array(array)?;
hash_list_array(array, random_state, hashes_buffer)?;
}
DataType::Map(_, _) => {
let array = as_map_array(array)?;
hash_map_array(array, random_state, hashes_buffer)?;
}
DataType::FixedSizeList(_,_) => {
let array = as_fixed_size_list_array(array)?;
hash_fixed_list_array(array, random_state, hashes_buffer)?;
Expand Down Expand Up @@ -572,6 +610,7 @@ mod tests {
Some(vec![Some(3), None, Some(5)]),
None,
Some(vec![Some(0), Some(1), Some(2)]),
Some(vec![]),
];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data)) as ArrayRef;
Expand All @@ -581,6 +620,7 @@ mod tests {
assert_eq!(hashes[0], hashes[5]);
assert_eq!(hashes[1], hashes[4]);
assert_eq!(hashes[2], hashes[3]);
assert_eq!(hashes[1], hashes[6]); // null vs empty list
}

#[test]
Expand Down Expand Up @@ -692,6 +732,64 @@ mod tests {
assert_eq!(hashes[0], hashes[1]);
}

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
fn create_hashes_for_map_arrays() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Appends one map row made of the given (key, value) entries. A `None`
    // value appends a null map value for that key; `valid = false` makes the
    // whole row a null map.
    let mut append_row = |entries: &[(&str, Option<i32>)], valid: bool| {
        for (key, value) in entries {
            builder.keys().append_value(*key);
            match value {
                Some(v) => builder.values().append_value(*v),
                None => builder.values().append_null(),
            }
        }
        builder.append(valid).unwrap();
    };

    append_row(&[("key1", Some(1)), ("key2", Some(2))], true); // Row 0
    append_row(&[("key1", Some(1)), ("key2", Some(2))], true); // Row 1: identical to row 0
    append_row(&[("key1", Some(1)), ("key2", Some(3))], true); // Row 2: different value
    append_row(&[("key1", Some(1)), ("key3", Some(2))], true); // Row 3: different key
    append_row(&[("key1", Some(1))], true); // Row 4: one entry fewer
    append_row(&[("key1", None)], true); // Row 5: null value
    append_row(&[], true); // Row 6: empty map
    append_row(&[("key1", Some(1))], false); // Row 7: null map

    let array = Arc::new(builder.finish()) as ArrayRef;

    let random_state = RandomState::with_seeds(0, 0, 0, 0);
    let mut hashes = vec![0; array.len()];
    create_hashes(&[array], &random_state, &mut hashes).unwrap();
    assert_eq!(hashes[0], hashes[1]); // same value
    assert_ne!(hashes[0], hashes[2]); // different value
    assert_ne!(hashes[0], hashes[3]); // different key
    assert_ne!(hashes[0], hashes[4]); // missing an entry
    assert_ne!(hashes[4], hashes[5]); // filled vs null value
    assert_eq!(hashes[6], hashes[7]); // empty vs null map
}

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,7 @@ impl ScalarValue {
}
DataType::List(_)
| DataType::LargeList(_)
| DataType::Map(_, _)
| DataType::Struct(_)
| DataType::Union(_, _) => {
let arrays = scalars.map(|s| s.to_array()).collect::<Result<Vec<_>>>()?;
Expand Down Expand Up @@ -1838,7 +1839,6 @@ impl ScalarValue {
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
| DataType::Map(_, _)
| DataType::RunEndEncoded(_, _)
| DataType::ListView(_)
| DataType::LargeListView(_) => {
Expand Down
8 changes: 8 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), a
{POST: 41, HEAD: 33, PATCH: 30}
{POST: 41, HEAD: 33, PATCH: 30}
{POST: 41, HEAD: 33, PATCH: 30}


query ?
VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1]))
----
{a: 1}
{b: 2}
{c: 3, a: 1}
143 changes: 125 additions & 18 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer};
use async_recursion::async_recursion;
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::array::{GenericListArray, MapArray};
use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
Expand Down Expand Up @@ -51,6 +51,7 @@ use crate::variation_const::{
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF,
};
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
Expand Down Expand Up @@ -1449,21 +1450,14 @@ fn from_substrait_type(
from_substrait_type(value_type, extensions, dfs_names, name_idx)?,
true,
));
match map.type_variation_reference {
DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
Ok(DataType::Map(
Arc::new(Field::new_struct(
"entries",
[key_field, value_field],
false, // The inner map field is always non-nullable (Arrow #1697),
)),
false,
))
}
v => not_impl_err!(
"Unsupported Substrait type variation {v} of type {s_kind:?}"
)?,
}
Ok(DataType::Map(
Arc::new(Field::new_struct(
"entries",
[key_field, value_field],
false, // The inner map field is always non-nullable (Arrow #1697),
)),
false, // whether keys are sorted
))
}
r#type::Kind::Decimal(d) => match d.type_variation_reference {
DECIMAL_128_TYPE_VARIATION_REF => {
Expand Down Expand Up @@ -1743,11 +1737,23 @@ fn from_substrait_literal(
)
}
Some(LiteralType::List(l)) => {
// Each element should start the name index from the same value, then we increase it
// once at the end
let mut element_name_idx = *name_idx;
let elements = l
.values
.iter()
.map(|el| from_substrait_literal(el, extensions, dfs_names, name_idx))
.map(|el| {
element_name_idx = *name_idx;
from_substrait_literal(
el,
extensions,
dfs_names,
&mut element_name_idx,
)
})
.collect::<Result<Vec<_>>>()?;
*name_idx = element_name_idx;
if elements.is_empty() {
return substrait_err!(
"Empty list must be encoded as EmptyList literal type, not List"
Expand Down Expand Up @@ -1785,6 +1791,84 @@ fn from_substrait_literal(
}
}
}
Some(LiteralType::Map(m)) => {
// Each entry should start the name index from the same value, then we increase it
// once at the end
let mut entry_name_idx = *name_idx;
let entries = m
.key_values
.iter()
.map(|kv| {
entry_name_idx = *name_idx;
let key_sv = from_substrait_literal(
kv.key.as_ref().unwrap(),
extensions,
dfs_names,
&mut entry_name_idx,
)?;
let value_sv = from_substrait_literal(
kv.value.as_ref().unwrap(),
extensions,
dfs_names,
&mut entry_name_idx,
)?;
ScalarStructBuilder::new()
.with_scalar(Field::new("key", key_sv.data_type(), false), key_sv)
.with_scalar(
Field::new("value", value_sv.data_type(), true),
value_sv,
)
.build()
})
.collect::<Result<Vec<_>>>()?;
*name_idx = entry_name_idx;

if entries.is_empty() {
return substrait_err!(
"Empty map must be encoded as EmptyMap literal type, not Map"
);
}

ScalarValue::Map(Arc::new(MapArray::new(
Arc::new(Field::new("entries", entries[0].data_type(), false)),
OffsetBuffer::new(vec![0, entries.len() as i32].into()),
ScalarValue::iter_to_array(entries)?.as_struct().to_owned(),
None,
false,
)))
}
Some(LiteralType::EmptyMap(m)) => {
let key = match &m.key {
Some(k) => Ok(k),
_ => plan_err!("Missing key type for empty map"),
}?;
let value = match &m.value {
Some(v) => Ok(v),
_ => plan_err!("Missing value type for empty map"),
}?;
let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?;
let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?;

// new_empty_array on a MapType creates an array that is "too empty":
// we want it to contain an empty struct array, matching what an empty MapBuilder produces
let entries = Field::new_struct(
"entries",
vec![
Field::new("key", key_type, false),
Field::new("value", value_type, true),
],
false,
);
let struct_array =
new_empty_array(entries.data_type()).as_struct().to_owned();
ScalarValue::Map(Arc::new(MapArray::new(
Arc::new(entries),
OffsetBuffer::new(vec![0, 0].into()),
struct_array,
None,
false,
)))
}
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
Expand Down Expand Up @@ -2013,6 +2097,29 @@ fn from_substrait_null(
),
}
}
r#type::Kind::Map(map) => {
let key_type = map.key.as_ref().ok_or_else(|| {
substrait_datafusion_err!("Map type must have key type")
})?;
let value_type = map.value.as_ref().ok_or_else(|| {
substrait_datafusion_err!("Map type must have value type")
})?;

let key_type =
from_substrait_type(key_type, extensions, dfs_names, name_idx)?;
let value_type =
from_substrait_type(value_type, extensions, dfs_names, name_idx)?;
let entries_field = Arc::new(Field::new_struct(
"entries",
vec![
Field::new("key", key_type, false),
Field::new("value", value_type, true),
],
false,
));

DataType::Map(entries_field, false /* keys sorted */).try_into()
}
r#type::Kind::Struct(s) => {
let fields =
from_substrait_struct_type(s, extensions, dfs_names, name_idx)?;
Expand Down
Loading

0 comments on commit f80dde0

Please sign in to comment.